训练和评估部分主要目的是生成用于测试用的pb文件,其保存了利用TensorFlow python API构建训练后的网络拓扑结构和参数信息,实现方式有不少种,除了cnn外还能够使用rnn,fcnn等。
其中基于cnn的函数也有两套,分别为tf.layers.conv2d和tf.nn.conv2d, tf.layers.conv2d使用tf.nn.conv2d做为后端处理,参数上filters是整数,filter是4维张量。原型以下:java
def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding=’valid’, data_format=’channels_last’,
dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None,
reuse=None)python
def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None)
官方Demo实例中使用的是layers module,结构以下:android
核心代码在cnn_model_fn(features, labels, mode)函数中,完成卷积结构的完整定义,核心代码以下:
git
也能够采用传统的tf.nn.conv2d函数, 核心代码以下:
后端
String actualFilename = labelFilename.split(“file:///android_asset/“)[1]; Log.i(TAG, “Reading labels from: “ + actualFilename); BufferedReader br = null; br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename))); String line; while ((line = br.readLine()) != null) { c.labels.add(line); } br.close();