当咱们训练一个deep learning模型时,怎么样判断当前是过拟合,仍是欠拟合等状态呢?实践中,咱们经常会将数据集分为三部分:train、validation、test。训练过程当中,咱们让模型尽力拟合train数据集,在validation数据集上测试拟合程度。当训练过程结束后,咱们在test集上测试模型最终效果。有经验的炼丹师每每会经过模型在train和validation上的表现,来判断当前是不是过拟合,是不是欠拟合。这个时候,TensorBoard就派上了大用场!python
有没有觉的一目了然呢?我强烈推荐你们使用TensorBoard,使用后炼丹功力显著提高!api
下面,我来说一下如何使用TensorBoard。要使用,也要优雅!
若是你喜欢本身梳理知识,本身尝试,那么不妨阅读官方文档:戳这里查看官方文档
否则的话,就随着老夫玩转TensorBoard吧 ^0^浏览器
熟悉一个新知识的时候,应该将没必要要的东西最精简化,将注意力集中到咱们最关注的地方,因此,我写了一个最简单的模型,在这个模型的基础上对TensorBoard进行探索。网络
首先看一下这个极简的线性模型:session
import tensorflow as tf import random class Model(object): def __init__(self): self.input_x = tf.placeholder(dtype=tf.float32, shape=[None, ], name='x') self.input_y = tf.placeholder(dtype=tf.float32, shape=[None, ], name='y') W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), dtype=tf.float32) b = tf.Variable(tf.random_uniform([1], -1.0, 1.0), dtype=tf.float32) y_predict = self.input_x * W + b self.loss = tf.reduce_sum(tf.abs(y_predict - self.input_y))
相信这个模型你们很快就能看懂,因此就很少说了。接下来看构造数据的代码:app
x_all = [] y_all = [] random.seed(10) for i in range(3000): x = random.random() y = 0.3 * x + 0.1 + random.random() x_all.append(x) y_all.append(y) x_all = np.array(x_all) y_all = np.array(y_all) shuffle_indices = np.random.permutation(np.arange(len(x_all))) x_shuffled = x_all[shuffle_indices] y_shuffled = y_all[shuffle_indices] bound = int(len(x_all) / 10 * 7) x_train = x_shuffled[:bound] y_train = y_shuffled[:bound] x_val = x_shuffled[bound:] y_val = y_shuffled[bound:]
这段代码里作了三件事:dom
下面是对数据按batch取出:ide
def batch_iter(data, batch_size, num_epochs, shuffle=True): """ Generates a batch iterator for a dataset. """ data = np.array(data) data_size = len(data) num_batches_per_epoch = int((len(data)-1)/batch_size) + 1 for epoch in range(num_epochs): # Shuffle the data at each epoch if shuffle: shuffle_indices = np.random.permutation(np.arange(data_size)) shuffled_data = data[shuffle_indices] else: shuffled_data = data for batch_num in range(num_batches_per_epoch): start_index = batch_num * batch_size end_index = min((batch_num + 1) * batch_size, data_size) yield shuffled_data[start_index:end_index]
而后就到了比较本篇博客的核心部分:
首先我来描述一下关键的函数(大部分同窗心里必定是拒绝的 2333,因此建议先看下面的代码,而后再反过头来看函数的介绍):函数
tf.summary.scalar(name, tensor, collections=None, family=None),调用这个函数来观察Tensorflow的Graph中某个节点测试
tf.summary.merge(inputs, collections=None, name=None)
tf.summary.FileWriter,在给定的目录中建立一个事件文件(event file),将summraies保存到该文件夹中。
__init__(logdir, graph=None, max_queue=10, flush_secs=120, graph_def=None, filename_suffix=None)
add_summary(summary, global_step=None)
with tf.Graph().as_default(): sess = tf.Session() with sess.as_default(): m = model.Model() global_step = tf.Variable(0, name='global_step', trainable=False) optimizer = tf.train.AdamOptimizer(1e-2) grads_and_vars = optimizer.compute_gradients(m.loss) train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars, global_step=global_step) loss_summary = tf.summary.scalar('loss', m.loss) train_summary_op = tf.summary.merge([loss_summary]) train_summary_writer = tf.summary.FileWriter('./summary/train', sess.graph) dev_summary_op = tf.summary.merge([loss_summary]) dev_summary_writer = tf.summary.FileWriter('./summary/dev', sess.graph) def train_step(x_batch, y_batch): feed_dict = {m.input_x: x_batch, m.input_y: y_batch} _, step, summaries, loss = sess.run( [train_op, global_step, train_summary_op, m.loss], feed_dict) train_summary_writer.add_summary(summaries, step) def dev_step(x_batch, y_batch): feed_dict = {m.input_x: x_batch, m.input_y: y_batch} step, summaries, loss = sess.run( [global_step, dev_summary_op, m.loss], feed_dict) dev_summary_writer.add_summary(summaries, step) sess.run(tf.global_variables_initializer()) batches = batch_iter(list(zip(x_train, y_train)), 100, 100) for batch in batches: x_batch, y_batch = zip(*batch) train_step(x_batch, y_batch) current_step = tf.train.global_step(sess, global_step) if current_step % 3 == 0: print('\nEvaluation:') dev_step(x_val, y_val)
如今咱们就可使用TensorBoard查看训练过程了~~
在terminal中输入以下命令:
tensorboard --logdir=summary
TensorBoard 0.4.0rc3 at http://liudaoxing-Lenovo-Rescuer-15ISK:6006 (Press CTRL+C to quit)
没错!这就是咱们train和validation过程当中loss的状况。
点击GRAPHS,就能够看到网络的结构
麻雀虽小,五脏俱全。但愿你们有收获~