标签(空格分隔): TensorFlowgit
tensorflow模型保存函数为:dom
tf.train.Saver()
固然,除了上面最简单的保存方式,也能够指定保存的步数,多长时间保存一次,磁盘上最多保有几个模型(将前面的删除以保持固定个数),以下:函数
建立saver时指定参数:测试
saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)
其中:rest
保存模型时指定参数:code
saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)
如上,其中能够指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph等等。orm
示例:ci
import tensorflow as tf v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1") v2= tf.Variable(tf.zeros([200]), name="v2") v3= tf.Variable(tf.zeros([100]), name="v3") saver = tf.train.Saver() with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) saver.save(sess,"checkpoint/model.ckpt",global_step=1)
运行后,保存模型保存,获得四个文件:get
checkpoint中记录了已存储(部分)和最近存储的模型:input
model_checkpoint_path: "model.ckpt-1" all_model_checkpoint_paths: "model.ckpt-1" ...
meta file保存了graph结构,包括 GraphDef,SaverDef等,当存在meta file,咱们能够不在文件中定义模型,也能够运行,而若是没有meta file,咱们须要定义好模型,再加载data file,获得变量值。
index file为一个string-string table,table的key值为tensor名,value为serialized BundleEntryProto。每一个BundleEntryProto表述了tensor的metadata,好比那个data文件包含tensor、文件中的偏移量、一些辅助数据等。
data file保存了模型的全部变量的值,TensorBundle集合。
Restore模型的过程能够分为两个部分,首先是建立模型,能够手动建立,也能够从meta文件里加载graph进行建立。
模型加载为:
with tf.Session() as sess: saver = tf.train.import_meta_graph('/xx/model.ckpt.meta') saver.restore(sess, "/xx/model.ckpt")
.meta文件中保存了图的结构信息,所以须要在导入checkpoint以前导入它。不然,程序不知道checkpoint中的变量对应的变量。另外也能够:
# Recreate the EXACT SAME variables v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Now load the checkpoint variable values with tf.Session() as sess: saver = tf.train.Saver() saver.restore(sess, "/xx/model.ckpt") #saver.restore(sess, tf.train.latest_checkpoint('./'))
PS:不存在model.ckpt文件,saver.py中:Users only need to interact with the user-specified prefix... instead of any physical pathname.
固然,还有一点须要注意,并不是全部的TensorFlow模型都能将graph输出到meta文件中或者从meta文件中加载进来,若是模型有部分不能序列化的部分,则此种方法可能会无效。
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) tvs = [v for v in tf.trainable_variables()] for v in tvs: print(v.name) print(sess.run(v))
如名所言,以上是查看模型中的trainable variables;或者咱们也能够查看模型中的全部tensor或者operations,以下:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) gv = [v for v in tf.global_variables()] for v in gv: print(v.name)
上面经过global_variables()得到的与前trainable_variables相似,只是多了一些非trainable的变量,好比定义时指定为trainable=False的变量,或Optimizer相关的变量。
下面则能够得到几乎全部的operations相关的tensor:
with tf.Session() as sess: saver = tf.train.import_meta_graph('model.ckpt-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) ops = [o for o in sess.graph.get_operations()] for o in ops: print(o.name)
首先,上面的sess.graph.get_operations()能够换为tf.get_default_graph().get_operations(),两者区别无非是graph明确的时候能够直接使用前者,不然须要使用后者。
此种方法得到的tensor比较齐全,能够从中一窥模型全貌。不过,最方便的方法仍是推荐使用tensorboard来查看,固然这须要你提早将sess.graph输出。
这种操做比较简单,无非是找到原始模型的输入、输出便可。
只要搞清楚输入输出的tensor名字,便可直接使用TensorFlow中graph的get_tensor_by_name函数,创建输入输出的tensor:
with tf.get_default_graph() as graph: data = graph.get_tensor_by_name('data:0') output = graph.get_tensor_by_name('output:0')
从模型中找到了输入输出以后,便可直接使用其继续train整个模型,或者将输入数据feed到模型里,并前传获得test输出了。
须要说明的是,有时候从一个graph里找到输入和输出tensor的名字并不容易,因此,在定义graph时,最好能给相应的tensor取上一个明显的名字,好比:
data = tf.placeholder(tf.float32, shape=shape, name='input_data') preds = tf.nn.softmax(logits, name='output')
诸如此类。这样,就能够直接使用tf.get_tensor_by_name(‘input_data:0’)之类的来找到输入输出了。
除了直接使用原始模型,还能够在原始模型上进行扩展,好比对1中的output继续进行处理,添加新的操做,能够完成对原始模型的扩展,如:
with tf.get_default_graph() as graph: data = graph.get_tensor_by_name('data:0') output = graph.get_tensor_by_name('output:0') logits = tf.nn.softmax(output)
有时候,咱们有对某模型的一部分进行fine-tune的需求,好比使用一个VGG的前面提取特征的部分,而微调其全连层,或者将其全连层更换为使用convolution来完成,等等。TensorFlow也提供了这种支持,可使用TensorFlow的stop_gradient函数,将模型的一部分进行冻结。
with tf.get_default_graph() as graph: graph.get_tensor_by_name('fc1:0') fc1 = tf.stop_gradient(fc1) # add new procedure on fc1