import tensorflow as tf def store_model_ckpt(ckpt_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') #模型的保存必须有变量 c = tf.Variable(1, name='c') a = tf.add(x, y, name='op') result = tf.add(a, c) with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init) saver = tf.train.Saver() #若是只保存其中一部分变量,则使用下面代码,用列表或者字典均可以 #saver = tf.train.Saver([x, y]) #这里面有参数global_step=50,当训练50步便保存模型 saver.save(sess, ckpt_file_path) # test feed_dict = {x: 2, y: 3} print(sess.run(result, feed_dict)) def main(): ckpt_file_path = "./ckpt/model.ckpt" store_model_ckpt(ckpt_file_path) if __name__ == '__main__': main()
结果:6node
程序生成并保存四个文件python
针对上面的模型保存例子,还原模型的过程以下:git
import tensorflow as tf def restore_model_ckpt(): with tf.Session() as sess: #step1:加载模型结构 saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') #step2:只须要指定目录就能够恢复全部变量信息 saver.restore(sess,tf.train.latest_checkpoint('./ckpt')) #直接获取保存的变量 print(sess.run('c:0')) #获取placeholder变量,经过get_tensor_by_name x = sess.graph.get_tensor_by_name('x:0') y = sess.graph.get_tensor_by_name('y:0') #获取须要进行计算的op算子,此op为加法 op = sess.graph.get_tensor_by_name('op:0') #加入新的op操做,新的op为乘法 new_op = tf.multiply(op, 2) #test feed_dict = {x:2, y:3} result = sess.run(new_op,feed_dict) print(result) def main(): restore_model_ckpt() if __name__ == '__main__': main()
结果:10浏览器
1. 首先还原模型结构网络
2. 而后还原变量(参数)信息架构
3. 最后咱们就能够得到已训练的模型中的各类信息了(保存的变量、placeholder变量、operator等),同时能够对获取的变量添加各类新的操做(见以上代码注释)。
而且,咱们也能够加载部分模型,在此基础上加入其它操做,具体能够参考官方文档和demo。dom
针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,能够参考。函数
同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也很是好,能够参考。源码分析
import tensorflow as tf from tensorflow.python.framework import graph_util def store_model_pb(pb_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') b = tf.Variable(1, name='b') a = tf.add(x, y) #该op算子应该加上name op = tf.add(a, b, name='op') with tf.Session() as sess: init = tf.initialize_all_variables() sess.run(init) #导出当前计算图的GraphDef部分,只须要这一部分就能够完成从输入层到输出层的计算 graph_def = tf.get_default_graph().as_graph_def() #将图中的变量及其取值转化为常量,同时将图中的没必要要的节点去掉 output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['op']) with tf.gfile.FastGFile(pb_file_path, mode='wb') as f: f.write(output_graph_def.SerializeToString()) #test feed_dict = {x: 2, y: 3} print(sess.run(op, feed_dict)) def main(): pb_file_path = "model.pb" store_model_pb(pb_file_path) if __name__ == '__main__': main()
结果:6 测试
在当前文件下面生成model.pb文件
import tensorflow as tf from tensorflow.python.platform import gfile def restore_model_pb(pb_file_path): with tf.Session() as sess: with gfile.FastGFile(pb_file_path, 'rb') as f: graph_def = tf.GraphDef() #转换成字符串形式 graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def, name='') #获取placeholder的变量 x = sess.graph.get_tensor_by_name('x:0') y = sess.graph.get_tensor_by_name('y:0') #获取op算子 op = sess.graph.get_tensor_by_name('op:0') feed_dict = {x: 2, y:3} result = sess.run(op,feed_dict) print(result) def main(): pb_file_path = "model.pb" restore_model_pb(pb_file_path) if __name__ == '__main__': main()
结果:5
但不少时候,咱们须要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其余地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认状况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,所以须要采用别的方法。 咱们知道,graph_def文件中没有包含网络中的Variable值(一般状况存储了权重),可是却包含了constant值,因此若是咱们能把Variable转换为constant,便可达到使用一个文件同时存储网络架构与权重的目标。
TensoFlow为咱们提供了convert_variables_to_constants()方法,该方法能够固化模型结构,将计算图中的变量取值以常量的形式保存,并且保存的模型能够移植到Android平台。
将CKPT 转换成 PB格式的文件的过程可简述以下:
1. 经过传入 CKPT 模型的路径获得模型的图和变量数据
2. 经过 import_meta_graph 导入模型中的图
3. 经过 saver.restore 从模型中恢复图中各个变量的数据
4. 经过 graph_util.convert_variables_to_constants 将模型持久化
Code:freeze_graph.py
import tensorflow as tf from tensorflow.python.framework import graph_util def freeze_graph(ckpt_file_path, pb_file_path): #“input:0”是张量的名称,而"input"表示的是节点的名称。 #此处输入的应该是节点的名称 output_node_names = "op" #首先恢复图结构 saver = tf.train.import_meta_graph(ckpt_file_path+'.meta',clear_devices=True) graph = tf.get_default_graph() input_graph_def = graph.as_graph_def() with tf.Session() as sess: #恢复图并获得数据 saver.restore(sess,ckpt_file_path) output_graph_def = graph_util.convert_variables_to_constants( sess=sess, input_graph_def=input_graph_def, #若是有多个输出节点 output_node_names=output_node_names.split(",")) with tf.gfile.GFile(pb_file_path,"wb") as f: f.write(output_graph_def.SerializeToString()) print("%d ops in the final graph." % len(output_graph_def.node)) def main(): # 输入ckpt模型路径 model_folder = "D:\AI\Ckpt\TestCkpt\ckpt" #检查目录下ckpt文件状态是否可用 checkpoint = tf.train.get_checkpoint_state(model_folder) #得ckpt文件路径 ckpt_file_path = checkpoint.model_checkpoint_path # 输出pb模型的路径 pb_file_path="frozen_model.pb" # 调用freeze_graph将ckpt转为pb freeze_graph(ckpt_file_path,pb_file_path) if __name__ == '__main__': main()
结果:生成 frozen_model.pb文件,能够采用上面pb模型加载的方法测试该pb文件
说明:
一、函数freeze_graph中,最重要的就是要肯定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操做,咱们须要定义输出结点的名字。由于网络实际上是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所须要的子图都固化下来,其余无关的就舍弃掉。由于咱们freeze模型的目的是接下来作预测。因此,output_node_names通常是网络模型最后一层输出的节点名称,或者说就是咱们预测的目标。
二、在保存的时候,经过convert_variables_to_constants函数来指定须要固化的节点名称,对于鄙人的代码,须要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。
三、源码中经过graph = tf.get_default_graph()得到默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,所以必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。
四、上面以及说明:在保存的时候,经过convert_variables_to_constants函数来指定须要固化的节点名称,对于鄙人的代码,须要固化的节点只有一个:output_node_names。所以,其余网络模型,也能够经过简单的修改输出的节点名称output_node_names,将ckpt转为pb文件 。
PS:注意节点名称,应包含name_scope 和 variable_scope命名空间,并用“/”隔开,如"InceptionV3/Logits/SpatialSqueeze"
# -*- coding: utf-8 -*- """ Created on Sat Dec 22 09:49:04 2018 @author: weilong """ import tensorflow as tf #定义简单的计算图,实现向量加法的操做 with tf.name_scope("imput1"): input1 = tf.constant([1.0, 2.0, 3.0], name="input1") with tf.name_scope("input2"): input2 = tf.Variable(tf.random_uniform([3]), name="input2") output = tf.add_n([input1, input2], name="add") #生成写日志的writer,并将当前的tensorflow计算图写入日志 writer = tf.summary.FileWriter("./log", tf.get_default_graph()) writer.close()
import tensorflow as tf model = 'model.pb' #请将这里的pb文件路径改成本身的 graph = tf.get_default_graph() graph_def = graph.as_graph_def() graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read()) tf.import_graph_def(graph_def, name='graph') summaryWriter = tf.summary.FileWriter('log/', graph)
执行以上代码就会生成文件在log/events.out.tfevents.1535079670.DESKTOP-5IRM000。
在tensorboard中加载:
tensorboard --logdir=\path\to\log
在浏览器中
拷贝网站连接在浏览器中便可。
参考:https://blog.csdn.net/guyuealian/article/details/82218092