tensorflow中moving average的用法

通常在保存模型参数的时候,都会保存一份moving average,是取了不一样迭代次数模型的移动平均,移动平均后的模型每每在性能上会比最后一次迭代保存的模型要好一些。html

tensorflow-models项目中tutorials下cifar中相关的代码写的有点问题,在这写下我本身的作法:git

 

1.构建训练模型时,添加以下代码github

1 variable_averages = tf.train.ExponentialMovingAverage(0.999, global_step) 2 variables_averages_op = variable_averages.apply(tf.trainable_variables()) 3 ave_vars = [variable_averages.average(var) for var in tf.trainable_variables()] 4 train_op = tf.group(train_op, variables_averages_op)

第1行建立了一个指数移动平均类 variable_averagessession

第2行将variable_averages做用于当前模型中全部可训练的变量上,获得 variables_averages_op操做符app

第3行得到全部可训练变量对应的移动平均变量列表集合,后续用于保存模型性能

第4行在原有的训练操做符基础上,再添加variables_averages_op操做符,后续session执行run的时候,除了训练时前向后向,梯度更新,还会对相应的变量作移动平均测试

 

2.开始训练前,建立saver时,使用以下代码spa

1 save_vars = tf.trainable_variables() + ave_vars
2 saver = tf.train.Saver(var_list=save_vars, max_to_keep=5)

第1行获取全部须要保存的变量列表,这个时候 ave_vars就派上用场了。rest

第2行建立saver,指定var_list为全部可训练变量及其对应的移动平均变量。code

另外须要注意的是,若是你的模型中有bn或者相似层,包含有统计参数(均值、方差等),这些不属于可训练参数,还须要额外添加进save_vars中,能够参考个人这篇博客

 

3.在作inference的时候,利用以下代码从checkpoint中恢复出移动平均模型

1 variable_averages = tf.train.ExponentialMovingAverage(0.999) 2 variables_to_restore = variable_averages.variables_to_restore() 3 saver = tf.train.Saver(variables_to_restore) 4 saver.restore(sess, model_path)

这几行很简单,就不作解释了。

实际上,在inference的时候,刚刚的作法除了能够从checkpoint文件中恢复出移动平均参数,还能够恢复出对应迭代的模型参数,能够用来对比两种方式,哪一种效果更好,这时只须要将上面代码的第3行改成saver = tf.train.Saver(tf.trainable_variables())便可(和保存时相同,若是有bn,也须要额外考虑)。在个人测试中,使用移动平均参数效果更佳。

相关文章
相关标签/搜索