tensorflow中batch normalization的用法

网上找了下tensorflow中使用batch normalization的博客,发现写的都不是很好,在此总结下:网络

1.原理学习

公式以下:测试

y=γ(x-μ)/σ+βspa

其中x是输入,y是输出,μ是均值,σ是方差,γ和β是缩放(scale)、偏移(offset)系数。code

通常来说,这些参数都是基于channel来作的,好比输入x是一个16*32*32*128(NWHC格式)的feature map,那么上述参数都是128维的向量。其中γ和β是无关紧要的,有的话,就是一个能够学习的参数(参与前向后向),没有的话,就简化成y=(x-μ)/σ。而μ和σ,在训练的时候,使用的是batch内的统计值,测试/预测的时候,采用的是训练时计算出的滑动平均值。orm

 

2.tensorflow中使用blog

tensorflow中batch normalization的实现主要有下面三个:ip

tf.nn.batch_normalizationci

tf.layers.batch_normalization字符串

tf.contrib.layers.batch_norm

封装程度逐个递进,建议使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,由于在tensorflow官网的解释比较详细。我平时多使用tf.layers.batch_normalization,所以下面的步骤都是基于这个。

 

3.训练

训练的时候须要注意两点,(1)输入参数training=True,(2)计算loss时,要添加如下代码(即添加update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss)

 

4.测试

测试时须要注意一点,输入参数training=False,其余就没了

 

5.预测

预测时比较特别,由于这一步通常都是从checkpoint文件中读取模型参数,而后作预测。通常来讲,保存checkpoint的时候,不会把全部模型参数都保存下来,由于一些无关数据会增大模型的尺寸,常见的方法是只保存那些训练时更新的参数(可训练参数),以下:

var_list = tf.trainable_variables() saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

 

但使用了batch_normalization,γ和β是可训练参数没错,μ和σ不是,它们仅仅是经过滑动平均计算出的,若是按照上面的方法保存模型,在读取模型预测时,会报错找不到μ和σ。更诡异的是,利用tf.moving_average_variables()也无法获取bn层中的μ和σ(也多是我用法不对),不过好在全部的参数都在tf.global_variables()中,所以能够这么写:

var_list = tf.trainable_variables() g_list = tf.global_variables() bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] var_list += bn_moving_vars saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

按照上述写法,便可把μ和σ保存下来,读取模型预测时也不会报错,固然输入参数training=False仍是要的。

注意上面有个不严谨的地方,由于个人网络结构中只有bn层包含moving_mean和moving_variance,所以只根据这两个字符串作了过滤,若是你的网络结构中其余层也有这两个参数,但你不须要保存,建议使用诸如bn/moving_mean的字符串进行过滤。

 

2018.4.22更新

提供一个基于mnist的示例,供你们参考。包含两个文件,分别用于train/test。注意bn_train.py文件的51-61行,仅保存了网络中的可训练变量和bn层利用统计获得的mean和var。注意示例中须要下载mnist数据集,要保持电脑能够联网。

相关文章
相关标签/搜索