1. 准备数据：使用占位符，动态加载训练数据
# Input features: 784 values per sample (presumably flattened 28x28 MNIST images).
x = tf.placeholder(tf.float32, [None, 784])
# Labels: one-hot vectors over the 10 digit classes.
y_true = tf.placeholder(tf.int32, [None, 10])
2. 初始化参数，创建模型
# Weights drawn from N(0, 1); one column per output class.
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0))
# Bias starts at zero. (Was misspelled `tf.canstant`, which raises AttributeError.)
bias = tf.Variable(tf.constant(0.0, shape=[10]))
# Unnormalized class scores (logits): x @ W + b.
y_predict = tf.matmul(x, weight) + bias
3. 求平均交叉熵损失
# Softmax cross-entropy for each sample in the batch, comparing the
# one-hot labels with the logits ...
per_sample_xent = tf.nn.softmax_cross_entropy_with_logits(
    labels=y_true,
    logits=y_predict,
)
# ... averaged into a single scalar loss.
loss = tf.reduce_mean(per_sample_xent)
4. 梯度下降优化
# Plain gradient descent with learning rate 0.3. The optimizer class lives in
# the `tf.train` namespace; `tf.GradientDescentOptimizer` does not exist.
train_op = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
5. 求准确率
# True where the predicted class (arg-max of the logits) equals the labelled
# class (arg-max of the one-hot label). `tf.argmax` replaces the deprecated
# `tf.arg_max`.
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
# Fraction of correct predictions in the batch.
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
完整代码：
"""Softmax regression (one fully-connected layer) on MNIST with TensorFlow 1.x.

With ``is_train=True`` the model is trained and checkpointed; with the
default ``is_train=False`` the saved checkpoint is restored and used to
classify test images.
"""
import collections
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST with one-hot encoded labels.
# NOTE(review): 'MNISI' looks like a typo for 'MNIST'; the path is kept so
# data already stored there keeps being used.
mnist = input_data.read_data_sets('./data/MNISI_data/', one_hot=True)

SUMMARY_DIR = './temp/summary/test'
CKPT_DIR = './temp/ckpt'
CKPT_PATH = CKPT_DIR + '/full_conn'
# Bookkeeping file written next to the checkpoint; its presence is used as
# the "a previous checkpoint exists" marker when resuming training.
CKPT_INDEX = CKPT_DIR + '/checkpoint'

# Graph tensors/ops that training and prediction need.
_Model = collections.namedtuple(
    '_Model', ['x', 'y_true', 'y_predict', 'train_op', 'accuracy', 'merged'])


def _build_model():
    """Build placeholders, linear model, loss, optimizer, accuracy and summaries.

    Variable scopes and names are the same as in the original script, so
    checkpoints saved by it under ``./temp/ckpt`` can still be restored.
    """
    # 1. Inputs: 784 features per sample, one-hot labels over 10 classes.
    with tf.variable_scope('data'):
        x = tf.placeholder(tf.float32, [None, 784])
        y_true = tf.placeholder(tf.int32, [None, 10])

    # 2. Linear model producing logits: x @ W + b.
    with tf.variable_scope('predict_model'):
        weight = tf.Variable(
            tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w')
        bias = tf.Variable(tf.constant(0.0, shape=[10]))
        y_predict = tf.matmul(x, weight) + bias

    # 3. Softmax cross-entropy averaged over the batch.
    with tf.variable_scope('loss'):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))

    # 4. Plain gradient descent.
    with tf.variable_scope('optimizer'):
        train_op = tf.train.GradientDescentOptimizer(0.4).minimize(loss)

    # 5. Accuracy: fraction of samples whose arg-max prediction matches the label.
    with tf.variable_scope('acc'):
        equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
        accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

    # Summaries for TensorBoard.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.histogram('weight', weight)
    tf.summary.histogram('bias', bias)
    merged = tf.summary.merge_all()

    return _Model(x, y_true, y_predict, train_op, accuracy, merged)


def _train(sess, model, init_op, saver, steps, batch_size):
    """Run ``steps`` mini-batch updates, log to TensorBoard, then save a checkpoint.

    Resumes from the existing checkpoint when one is present.
    """
    sess.run(init_op)
    file_writer = tf.summary.FileWriter(SUMMARY_DIR, graph=sess.graph)
    try:
        if os.path.exists(CKPT_INDEX):
            # Continue from the previous run's weights instead of the random init.
            saver.restore(sess, CKPT_PATH)

        for step in range(steps):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            feed = {model.x: batch_x, model.y_true: batch_y}
            sess.run(model.train_op, feed_dict=feed)
            # Fetch summaries and accuracy in one run, both on the updated weights.
            summary, acc = sess.run([model.merged, model.accuracy], feed_dict=feed)
            file_writer.add_summary(summary, step)
            print("训练第%d步,准确率为:%f" % (step, acc))
    finally:
        # Flush pending events and release the event file.
        file_writer.close()

    # Saver.save refuses to write when the parent directory is missing.
    os.makedirs(CKPT_DIR, exist_ok=True)
    saver.save(sess, CKPT_PATH)


def _predict(sess, model, saver, num_images):
    """Restore the checkpoint and print the true vs. predicted digit per test image."""
    saver.restore(sess, CKPT_PATH)
    for i in range(num_images):
        x_test, y_test = mnist.test.next_batch(1)
        logits = sess.run(model.y_predict, feed_dict={model.x: x_test})
        # Take the arg-max on the returned arrays: creating tf.argmax ops here
        # would add new nodes to the graph on every iteration.
        print('第%d张图片,手写数字图片目标:%d--%d' % (
            i, y_test.argmax(axis=1)[0], logits.argmax(axis=1)[0]))


def full_connection(is_train=False, steps=4000, batch_size=50, num_images=100):
    """Build the model, then either train it or classify test images with it.

    Args:
        is_train: If True, train (resuming from any existing checkpoint) and
            save a checkpoint; otherwise restore the checkpoint and predict.
        steps: Number of training mini-batches.
        batch_size: Training mini-batch size.
        num_images: Number of test images to classify in prediction mode.
    """
    model = _build_model()
    # Created after the model so they cover every variable it defines.
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        if is_train:
            _train(sess, model, init_op, saver, steps, batch_size)
        else:
            _predict(sess, model, saver, num_images)


if __name__ == '__main__':
    full_connection()