一:数据集数组
采用MNIST数据集:--》官网网络
数据集被分红两部分:60000行的训练数据集和10000行的测试数据集。函数
其中每一张图片包含28*28个像素,咱们把这个数组展开成一个向量,长度为28*28=784.在MNIST训练数据集中mnist.train.images是一个形状为[60000,784]的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。图片里的某个像素的强度值介于0-1之间。测试
MNIST数据集的标签是介于0-9的数字,咱们把便签转化为‘one-hot vectors’.一个one-hot向量除了某一位数字1之外,其他维度数字都是0.好比标签0将表示为([1,0,0,0,0,0,0,0,0,0,0]),标签3表示为([0,0,0,1,0,0,0,0,0,0]).因此标签至关于[60000,10]的数字矩阵。spa
咱们的结果是0-9,咱们的模型可能推测出一张图片是数字9的几率为80%,是数字8的几率为10%,而后其余数字的几率更小,整体几率加起来等于1.这至关于一个使用softmax回归模型的案例。code
下面使用softmax模型来预测:blog
# MNIST数据集 手写数字 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 载入数据集,若是没有下载,程序会自动下载 mnist=input_data.read_data_sets('MNIST_data',one_hot=True) # 每一个批次的大小 batch_size=100 # 计算一共有多少个批次 n_batch=mnist.train.num_examples//batch_size # 定义两个placeholder x=tf.placeholder(tf.float32,[None,784]) y=tf.placeholder(tf.float32,[None,10]) # 建立一个简单的神经网络 W=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) prediction=tf.nn.softmax(tf.matmul(x,W)+b) # 二次代价函数 loss=tf.reduce_mean(tf.square(y-prediction)) # 使用梯度降低法 train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化变量 init=tf.global_variables_initializer() # 求最大值在哪一个位置,结果存放在一个布尔值列表中 correct_prediction=tf.equal(tf.argmax(y,1),tf.arg_max(prediction,1))# argmax返回一维张量中最大值所在的位置 # 求准确率 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) # cast做用是将布尔值转换为浮点型。 with tf.Session() as sess: sess.run(init) for epoch in range(21): # 训练20次 for batch in range(n_batch): # 每次喂入必定的数据 batch_xs,batch_ys=mnist.train.next_batch(batch_size) sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) #求准确率 acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print('Iter:'+str(epoch)+',Testing Accuracy:'+str(acc))
# 结果 # 能够看出每次训练准确率都在提升 Iter:0,Testing Accuracy:0.8301 Iter:1,Testing Accuracy:0.8706 Iter:2,Testing Accuracy:0.8811 Iter:3,Testing Accuracy:0.8883 Iter:4,Testing Accuracy:0.8943 Iter:5,Testing Accuracy:0.8966 Iter:6,Testing Accuracy:0.9002 Iter:7,Testing Accuracy:0.9017 Iter:8,Testing Accuracy:0.9043 Iter:9,Testing Accuracy:0.9052 Iter:10,Testing Accuracy:0.9061 Iter:11,Testing Accuracy:0.9071 Iter:12,Testing Accuracy:0.908 Iter:13,Testing Accuracy:0.9096 Iter:14,Testing Accuracy:0.9094 Iter:15,Testing Accuracy:0.9102 Iter:16,Testing Accuracy:0.9116 Iter:17,Testing Accuracy:0.9119 Iter:18,Testing Accuracy:0.9126 Iter:19,Testing Accuracy:0.9134 Iter:20,Testing Accuracy:0.9136