使用TensorFlow训练模型的过程当中,须要适时对模型进行保存,以及对保存的模型进行restore,以方便后续对模型进行处理。好比进行测试,或者部署;好比拿别的模型进行fine-tune,等等。固然,直接的保存和restore比较简单,无需多言,可是保存和restore中还牵涉到其余问题,以及针对各类需求的各类参数等,可能不便一下都记好。所以,有必要对此进行一个总结。本文就是对使用TensorFlow保存和restore模型的相关内容进行一下总结,以便备忘。git
保存模型是整个内容的第一步,固然也十分简单。无非是建立一个saver,并在一个Session里完成保存。好比:函数
saver = tf.train.Saver() with tf.Session() as sess: saver.save(sess, model_name)
以上代码在0.11如下版本的TensorFlow里会保存与下面相似的3个文件:测试
checkpointspa
model.ckpt-1000.metarest
model.ckpt-1000.ckptcode
在0.11及以上版本的TensorFlow里则会保存与下相似的4个文件:部署
checkpointget
model.ckpt-1000.indexinput
model.ckpt-1000.data-00000-of-00001it
model.ckpt-1000.meta
其中checkpoint列出保存的全部模型以及最近的模型;meta文件是模型定义的内容;ckpt(或data和index)文件是保存的模型数据;内里细节无需过多关注,若是想了解,stackOverflow上有一个解释的回答。
固然,除了上面最简单的保存方式,也能够指定保存的步数,多长时间保存一次,磁盘上最多保有几个模型(将前面的删除以保持固定个数),以下:
建立saver时指定参数:
saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)
其中savable_variables指定待保存的变量,好比指定为tf.global_variables()保存全部global变量;指定为[v1, v2]保存v1和v2两个变量;若是省略,则保存全部;
max_to_keep指定磁盘上最多保有几个模型;keep_checkpoint_every_n_hours指定多少小时保存一次。
保存模型时指定参数:
saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)
如上,其中能够指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph等等。
具体来讲,Restore模型的过程能够分为两个部分,首先是建立模型,能够手动建立,也能够从meta文件里加载graph进行建立。
建立模型与训练模型时建立模型的代码相同,能够直接复制过来使用。
从meta文件里进行加载,能够直接在Session里进行以下操做:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
后面的参数直接使用meta文件的路径便可。如此,即将模型定义的graph加载进来了。
固然,还有一点须要注意,并不是全部的TensorFlow模型都能将graph输出到meta文件中或者从meta文件中加载进来,若是模型有部分不能序列化的部分,则此种方法可能会无效。
而后就是为模型加载数据,可使用下面两种方法:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./'))
此方法加载指定文件夹下最近保存的一个模型的数据;或者
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))
此方法能够指定具体某个数据,须要注意的是,指定的文件不要包含后缀。
将模型数据加载进来以后,下一步就是利用加载的模型进行下一步的操做了。这能够根据不一样须要以以下几种方式进行操做。
能够直接查看Restore进来的模型的参数,以下:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) tvs = [v for v in tf.trainable_variables()] for v in tvs: print(v.name) print(sess.run(v))
如名所言,以上是查看模型中的trainable variables;或者咱们也能够查看模型中的全部tensor或者operations,以下:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) gv = [v for v in tf.global_variables()] for v in gv: print(v.name)
上面经过global_variables()得到的与前trainable_variables相似,只是多了一些非trainable的变量,好比定义时指定为trainable=False的变量,或Optimizer相关的变量。
下面则能够得到几乎全部的operations相关的tensor:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) ops = [o for o in sess.graph.get_operations()] for o in ops: print(o.name)
首先,上面的sess.graph.get_operations()能够换为tf.get_default_graph().get_operations(),两者区别无非是graph明确的时候能够直接使用前者,不然须要使用后者。
此种方法得到的tensor比较齐全,能够从中一窥模型全貌。不过,最方便的方法仍是推荐使用tensorboard来查看,固然这须要你提早将sess.graph输出。
这种操做比较简单,无非是找到原始模型的输入、输出便可。
只要搞清楚输入输出的tensor名字,便可直接使用TensorFlow中graph的get_tensor_by_name函数,创建输入输出的tensor:
with tf.get_default_graph() as graph: data = graph.get_tensor_by_name('data:0') output = graph.get_tensor_by_name('output:0')
如上,须要特别注意,get_tensor_by_name后面传入的参数,若是没有重复,须要在后面加上“:0”。
从模型中找到了输入输出以后,便可直接使用其继续train整个模型,或者将输入数据feed到模型里,并前传获得test输出了。
须要说明的是,有时候从一个graph里找到输入和输出tensor的名字并不容易,因此,在定义graph时,最好能给相应的tensor取上一个明显的名字,好比:
data = tf.placeholder(tf.float32, shape=shape, name='input_data') preds = tf.nn.softmax(logits, name='output')
诸如此类。这样,就能够直接使用tf.get_tensor_by_name(‘input_data:0’)之类的来找到输入输出了。
除了直接使用原始模型,还能够在原始模型上进行扩展,好比对1中的output继续进行处理,添加新的操做,能够完成对原始模型的扩展,如:
with tf.get_default_graph() as graph: data = graph.get_tensor_by_name('data:0') output = graph.get_tensor_by_name('output:0') logits = tf.nn.softmax(output)
有时候,咱们有对某模型的一部分进行fine-tune的需求,好比使用一个VGG的前面提取特征的部分,而微调其全连层,或者将其全连层更换为使用convolution来完成,等等。TensorFlow也提供了这种支持,可使用TensorFlow的stop_gradient函数,将模型的一部分进行冻结。
with tf.get_default_graph() as graph: graph.get_tensor_by_name('fc1:0') fc1 = tf.stop_gradient(fc1) # add new procedure on fc1