Keras 提供 Callback 接口来追踪训练过程当中的每一步结果,包括每个 batch 和每个 epoch。虽然名为“回调函数”,但实际上想要扩展这功能须要继承 keras.callbacks.Callback
类,该类提供两个与模型训练过程相关的属性:html
params
:compile 模型时设定的参数;model
:模型对象。经过这一接口能够实时可视化 fit
过程当中每个 batch 和每个 epoch 迭代过程当中的偏差大小变化。以《Neural Networks and Deep Learning - Chap3 Improving the way neural networks learn》为例,假设咱们要训练一个最简单的神经网络:markdown
这个只有一个神经元的神经网络只有一个权重 w
和一个偏置 b
两个待训练的参数,假设要训练的数据只有 (1, 0)
,在这里比较 MSE 和 Cross Entropy 两种代价函数的学习效果。网络
首先构建这个模型:app
from keras import Sequential, initializers, optimizers from keras.layers import Activation, Dense import numpy as np def viz_keras_fit(w, b, runtime_plot=False, loss="mean_squared_error", act="sigmoid"): d = DrawCallback(runtime_plot=runtime_plot) # 初始化参数 w = initializers.Constant([w]) b = initializers.Constant([b]) x_train, y_train = np.array([1]), np.array([0]) model = Sequential() model.add(Dense(1, activation=act, input_shape=(1,), kernel_initializer=w, bias_initializer=b)) # Learning Rate = 0.15 sgd = optimizers.SGD(lr=0.15) model.compile(optimizer=sgd, loss=loss) model.fit(x = x_train, y = y_train, epochs=150, verbose=0, callbacks=[d]) # Callback List return d 复制代码
初始参数仍然是 (2, 2)
换成 Cross Entropy 做为 loss function 以后:函数
虽然实现了实时可视化,但绘图所用的时间可能比一个 epoch 耗时更久,所以先记录每一步的 loss 再绘图会更好一些:oop
实时观察模型的学习状况能够帮助咱们在初期选择损失函数、激活函数、模型结构以及超参数等。如下是 DrawCallback
的实现:学习
import pylab as pl from IPython import display from keras.callbacks import Callback class DrawCallback(Callback): def __init__(self, runtime_plot=True): super().__init__() self.init_loss = None self.runtime_plot = runtime_plot self.xdata = [] self.ydata = [] def _plot(self, epoch=None): epochs = self.params.get("epochs") pl.ylim(0, int(self.init_loss*2)) pl.xlim(0, epochs) pl.plot(self.xdata, self.ydata) pl.xlabel('Epoch {}/{}'.format(epoch or epochs, epochs)) pl.ylabel('Loss {:.4f}'.format(self.ydata[-1])) def _runtime_plot(self, epoch): self._plot(epoch) display.clear_output(wait=True) display.display(pl.gcf()) pl.gcf().clear() def plot(self): self._plot() pl.show() def on_epoch_end(self, epoch, logs = None): logs = logs or {} loss = logs.get("loss") if self.init_loss is None: self.init_loss = loss self.xdata.append(epoch) self.ydata.append(loss) if self.runtime_plot: self._runtime_plot(epoch) 复制代码