SkySeraph 2018html
Email:skyseraph00#163.comjava
本文系“SkySeraph AI 实践到理论系列”第一篇,咱以AI界的HelloWord 经典MNIST数据集为基础,在Android平台,基于TensorFlow,实现CNN的手写数字识别。
Code here~python
训练和评估部分主要目的是生成用于测试用的pb文件,其保存了利用TensorFlow python API构建训练后的网络拓扑结构和参数信息,实现方式有不少种,除了cnn外还可使用rnn,fcnn等。
其中基于cnn的函数也有两套,分别为tf.layers.conv2d和tf.nn.conv2d, tf.layers.conv2d使用tf.nn.conv2d做为后端处理,参数上filters是整数,filter是4维张量。原型以下:
convolutional.py文件
def conv2d(inputs, filters, kernel_size, strides=(1, 1), padding=’valid’, data_format=’channels_last’,
dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=None,
bias_initializer=init_ops.zeros_initializer(), kernel_regularizer=None, bias_regularizer=None,
activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True, name=None,
reuse=None)android
gen_nn_ops.py 文件git
def conv2d(input, filter, strides, padding, use_cudnn_on_gpu=True, data_format="NHWC", name=None)
官方Demo实例中使用的是layers module,结构以下:github
核心代码在cnn_model_fn(features, labels, mode)函数中,完成卷积结构的完整定义,核心代码以下.算法
也能够采用传统的tf.nn.conv2d函数, 核心代码以下。数据库
导入pb文件.pb文件放assets目录,而后读取后端
String actualFilename = labelFilename.split(“file:///android_asset/“)[1];
Log.i(TAG, “Reading labels from: “ + actualFilename);
BufferedReader br = null;
br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
String line;
while ((line = br.readLine()) != null) {
c.labels.add(line);
}
br.close();网络
TensorFlow接口使用
MNIST,最经典的机器学习模型之一,包含0~9的数字,28*28大小的单色灰度手写数字图片数据库,其中共60,000 training examples和10,000 test examples。
文件目录以下,主要包括4个二进制文件,分别为训练和测试图片及Label。
以下为训练图片的二进制结构,在真实数据前(pixel),有部分描述字段(魔数,图片个数,图片行数和列数),真实数据的存储采用大端规则。
(大端规则,就是数据的高字节保存在低内存地址中,低字节保存在高内存地址中)
在具体实验使用,须要提取真实数据,可采用专门用于处理字节的库struct中的unpack_from方法,核心方法以下:
struct.unpack_from(self._fourBytes2, buf, index)
MNIST做为AI的Hello World入门实例数据,TensorFlow封装对其封装好了函数,可直接使用
mnist = input_data.read_data_sets(‘MNIST’, one_hot=True)
神经网络。一个由大量神经元(neurons)组成的系统,以下图所示[21]
其中x表示输入向量,w为权重,b为偏值bias,f为激活函数。
Activation Function 激活函数: 经常使用的非线性激活函数有Sigmoid、tanh、ReLU等等,公式以下如所示。
机器学习有监督学习(supervised learning)中两大算法分别是分类算法和回归算法,分类算法用于离散型分布预测,回归算法用于连续型分布预测。
回归的目的就是创建一个回归方程用来预测目标值,回归的求解就是求这个回归方程的回归系数。
其中回归(Regression)算法包括Linear Regression,Logistic Regression等, Softmax Regression是其中一种用于解决多分类(multi-class classification)问题的Logistic回归算法的推广,经典实例就是在MNIST手写数字分类上的应用。
Linear Regression是机器学习中最基础的模型,其目标是用预测结果尽量地拟合目标label
MNIST
Softmax
CNN
TensorFlow+CNN / TensorFlow+Android
By SkySeraph-2018
SkySeraph cnBlogs