最近研究了下如何使用tensorflow进行finetuning,相比于caffe,tensorflow的finetuning麻烦一些,记录以下:网络
finetuning原理很简单,利用一个在数据A集上已训练好的模型做为初始值,改变其部分结构,在另外一数据集B上(采用小学习率)训练的过程叫作finetuning。学习
通常来说,符合以下状况会采用finetuningspa
在数据集A上训练的时候,和普通的tensorflow训练过程彻底一致。可是在数据集B上进行finetuning时,须要先从以前训练好的checkpoint中恢复模型参数,这个地方比较关键,rest
须要注意只恢复须要恢复的参数,其余参数不要恢复,不然会由于找不到的声明而报错。以mnist为例子,若是我想先训练一个0-7的8类分类器,网络结构以下:code
conv1-conv2-fc8(其余不带权重的pooling、softmaxloss层忽略)blog
而后我想用这个训练出的模型参数,在0-9的10类分类器上作finetuning,网络结构以下:ip
conv1-conv2-fc10get
那么在从checkpoint中恢复模型参数时,我只能恢复conv1-conv2,若是连fc8都恢复了,就会由于找不到fc8的定义而报错it
以上描述对应的代码以下:io
1 if tf.train.latest_checkpoint('ckpts') is not None: 2 trainable_vars = tf.trainable_variables() 3 res_vars = [t for t in trainable_vars if t.name.startswith('conv')] 4 saver = tf.train.Saver(var_list=res_vars) 5 saver.restore(sess, tf.train.latest_checkpoint('ckpts')) 6 else: 7 saver = tf.train.Saver()
利用mnist写了一个简单的finetuning例子,你们能够试试,事实证实,利用一个相关的已有模型作finuetuning比从0开始训练收敛的更快而且收敛到的准确率更高,