tensorflow笔记(四)之MNIST手写识别系列一

tensorflow笔记(四)之MNIST手写识别系列一
html

版权声明:本文为博主原创文章,转载请指明转载地址python

http://www.cnblogs.com/fydeblog/p/7436310.html
git

前言

这篇博客将利用神经网络去训练MNIST数据集,经过学习到的模型去分类手写数字。github

我会将本篇博客的jupyter notebook放在最后,方便你下载在线调试!推荐结合官方的tensorflow教程来看这个notebook!ubuntu

1. MNIST数据集的导入

这里介绍一下MNIST,MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.网络

首先咱们要导入MNIST数据集,这里须要用到一个input_data.py文件,在你安装tensorflow的examples/tutorials/MNIST目录下,若是tensorflow的目录下没有这个文件夹(通常是你的tensorflow版本不够新,1.2版本有的),还请本身导入或者更新一下tensorflow的版本,导入的方法是在tensorflow的github(https://github.com/tensorflow/tensorflow/tree/master/tensorflow  )下下载examples文件夹,粘贴到tensorflow的根目录下。更新tensorflow版本的话,请在ubuntu终端下运行pip install --upgrade tensorflow就能够了dom

好了,咱们仍是一步步来进行整个过程机器学习

首先咱们先导入咱们须要用到的模块函数

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist  import  input_data

而后咱们用input_data模块导入MNIST数据集学习

mnist = input_data.read_data_sets('MNIST_data',one_hot = True)

上面总共下载了四个压缩文件,内容分别以下:
train-images-idx3-ubyte.gz    训练集图片 - 55000 张 训练图片, 5000 张 验证图片
train-labels-idx1-ubyte.gz      训练集图片对应的数字标签
t10k-images-idx3-ubyte.gz   测试集图片 - 10000 张 图片
t10k-labels-idx1-ubyte.gz      测试集图片对应的数字标签

图片数据将被解压成2维的tensor:[image index, pixel index] 其中每一项表示某一图片中特定像素的强度值, 范围从 [0, 255] 到 [-0.5, 0.5]。 "image index"表明数据集中图片的编号, 从0到数据集的上限值。"pixel index"表明该图片中像素点得个数, 从0到图片的像素上限值。

以train-*开头的文件中包括60000个样本,其中分割出55000个样本做为训练集,其他的5000个样本做为验证集。由于全部数据集中28x28像素的灰度图片的尺寸为784,因此训练集输出的tensor格式为[55000, 784]

执行read_data_sets()函数将会返回一个DataSet实例,其中包含了如下三个数据集。 数据集 目的 data_sets.train 55000 组 图片和标签, 用于训练。 data_sets.validation 5000 组 图片和标签, 用于迭代验证训练的准确性。 data_sets.test 10000 组 图片和标签, 用于最终测试训练的准确性。

具体的MNIST数据集的解压和重构咱们能够不了解,会用这个数据集就能够了。(固然别问我这个东西,这个过程我也不知道,嘿嘿)

这里说一下上述代码中的one_hot,MNIST的标签数据是"one-hot vectors"。 一个one-hot向量除了某一位的数字是1之外其他各维度数字都是0。因此在此教程中,数字n将表示成一个只有在第n维度(从0开始)数字为1的10维向量。好比,标签0将表示成([1,0,0,0,0,0,0,0,0,0,0])。

2.实践

咱们首先定义两个占位符,来表示训练数据及其相应标签数据,将会在训练部分进行feed进去

xs = tf.placeholder(tf.float32,[None,784]) # 784 = 28X28
ys = tf.placeholder(tf.float32,[None,10]) # 10 = (0~9) one_hot

如今咱们再来定义神经网络的权重和误差

Weights = tf.Variable(tf.random_normal([784,10]))
biases = tf.Variable(tf.zeros([1,10])+0.2)

先说一下,这个神经网络是输入直接映射到输出,没有隐藏层,输入是每张图像28X28的像素,也就是784,输出是10个长度的向量,也就是10,因此权重是[784,10],误差是[1,10].

y_pre = tf.nn.softmax(tf.matmul(xs,Weights)+biases) 

咱们知道虽然最后的输出结果是10个长度的向量,但他们的值可能不太直观,打个比方,好比都是0.015之类的数,仅仅是打比方哈

为了显示输出结果对每一个数的相应几率,咱们加了一个softmax函数,它的原理很简单,拿10个单位的向量[x0,x1,...,x9]为例,若是想知道数字0的几率是多少,用exp(x0)/(exp(x0)+exp(x1)+...+exp(x9)),其余数字的几率相似推导,你也能够参考我放在博客上的图片,很直观。

cross_entropy =tf.reduce_mean( -tf.reduce_sum(ys*tf.log(y_pre),reduction_indices=[1]))#compute cross_entropy

此次的损失表示形式跟以前都不太同样哈,此次是计算交叉熵,交叉熵是用来衡量咱们的预测用于描述真相的有效性。咱们能够想想,以一张图片为例,y_pre和ys都是一个10个长度的向量,不一样的是y_pre每一个序号对应的值不为0,而ys是one_hot向量,只有一个为1,其他全为0,那么按照上述公式,只有1对应序号i(假如是i)的log(y_pre(i))保留下来了,并且y_pre(i)越大(也就是几率越大),log(y_pre(i))越小(注意计算交叉熵前面有负号的),反之越大,符合咱们对损失的概念。

我试过用官方教程的交叉熵公式,打印交叉熵时出现nan,溢出了,建议用这个好一些

train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

上面是用优化器最小化交叉熵,这里学习率的选取很重要,官方的0.01过小,收敛得慢,还没达到训练损失最小就中止了,结果就是测试集偏差较大,推荐选大点,0.5左右差很少了,再大反而会发散了。

init = tf.global_variables_initializer()

上面是生出初始化init

sess  = tf.Session()

创建一个会话

sess.run(init)

初始化变量

for i in xrange(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    sess.run(train,feed_dict={xs:batch_xs,ys:batch_ys})
    if i %50==0:
        print sess.run(cross_entropy,feed_dict={xs:batch_xs,ys:batch_ys})

上面是程序训练过程,这里说一下xrange和range的区别,它们两个的用法基本相同,但返回的类型不一样,xrange返回的是生成器,range返回的是列表,全部xrange更节省内存,推荐用xrange,python3当中已经没有xrange了,只有range,但它的功能和python2当中的xrange同样

下面咱们来计算计算精度

correct_prediction = tf.equal(tf.argmax(ys,1), tf.argmax(y_pre,1))

tf.argmax 是一个很是有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。tf.argmax(y_pre,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(ys,1) 表明正确的标签,咱们能够用 tf.equal 来检测咱们的预测是否真实标签匹配,这行代码返回的是匹配的布尔值,成功1,失败0

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

tf.cast将布尔类型的correct_prediction转化成float型,而后取平均获得精确度

print sess.run(accuracy, feed_dict={xs: mnist.test.images, ys: mnist.test.labels})

精确度87.79%,官方说的91%我是没达到过,我训练最高不超过89%。

3.结尾

但愿这篇博客能对你的学习有所帮助,谢谢观看!同时,有兴趣的朋友能够多改改参数试试不一样的结果,好比学习率,batch_size等等,这对你的理解也是有帮助的!

下一篇笔记将写用cnn去分类MNIST数据集,敬请期待!

连接: https://pan.baidu.com/s/1oWXk2Iai5f7I4U411XP8hQ

相关文章
相关标签/搜索