当咱们使用 tensorflow 训练神经网络的时候,模型持久化对于咱们的训练有很重要的做用。node
若是咱们的神经网络比较复杂,训练数据比较多,那么咱们的模型训练就会耗时很长,若是在训练过程当中出现某些不可预计的错误,致使咱们的训练意外终止,那么咱们将会前功尽弃。为了不这个问题,咱们就能够经过模型持久化(保存为CKPT格式)来暂存咱们训练过程当中的临时数据。python
若是咱们训练的模型须要提供给用户作离线的预测,那么咱们只须要前向传播的过程,只需获得预测值就能够了,这个时候咱们就能够经过模型持久化(保存为PB格式)只保存前向传播中须要的变量并将变量的值固定下来,这个时候只需用户提供一个输入,咱们就能够经过模型获得一个输出给用户。网络
# coding=UTF-8 支持中文编码格式 import tensorflow as tf import shutil import os.path MODEL_DIR = "model/ckpt" MODEL_NAME = "model.ckpt" # if os.path.exists(MODEL_DIR): 删除目录 # shutil.rmtree(MODEL_DIR) if not tf.gfile.Exists(MODEL_DIR): #建立目录 tf.gfile.MakeDirs(MODEL_DIR) #下面的过程你能够替换成CNN、RNN等你想作的训练过程,这里只是简单的一个计算公式 input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的 W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1") B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1") _y = (input_holder * W1) + B1 predictions = tf.greater(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,不然返回false init = tf.global_variables_initializer() saver = tf.train.Saver() #声明saver用于保存模型 with tf.Session() as sess: sess.run(init) print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}) #输入一个数据测试一下 saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME)) #模型保存 print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node)) #获得当前图有几个操做节点 for op in tf.get_default_graph().get_operations(): #打印模型节点信息 print (op.name, op.values())
运行后生成的文件以下:框架
# coding=UTF-8 import tensorflow as tf import shutil import os.path from tensorflow.python.framework import graph_util # MODEL_DIR = "model/pb" # MODEL_NAME = "addmodel.pb" # if os.path.exists(MODEL_DIR): 删除目录 # shutil.rmtree(MODEL_DIR) # # if not tf.gfile.Exists(MODEL_DIR): #建立目录 # tf.gfile.MakeDirs(MODEL_DIR) output_graph = "model/pb/add_model.pb" #下面的过程你能够替换成CNN、RNN等你想作的训练过程,这里只是简单的一个计算公式 input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1") B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1") _y = (input_holder * W1) + B1 # predictions = tf.greater(_y, 50, name="predictions") #比50大返回true,不然返回false predictions = tf.add(_y, 10,name="predictions") #作一个加法运算 init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}) graph_def = tf.get_default_graph().as_graph_def() #获得当前的图的 GraphDef 部分,经过这个部分就能够完成重输入层到输出层的计算过程 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定 sess, graph_def, ["predictions"] #须要保存节点的名字 ) with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型 f.write(output_graph_def.SerializeToString()) # 序列化输出 print("%d ops in the final graph." % len(output_graph_def.node)) print (predictions) # for op in tf.get_default_graph().get_operations(): 打印模型节点信息 # print (op.name)
*GraphDef:这个属性记录了tensorflow计算图上节点的信息。学习
# coding=UTF-8 import tensorflow as tf import os.path import argparse from tensorflow.python.framework import graph_util MODEL_DIR = "model/pb" MODEL_NAME = "frozen_model.pb" if not tf.gfile.Exists(MODEL_DIR): #建立目录 tf.gfile.MakeDirs(MODEL_DIR) def freeze_graph(model_folder): checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用 input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径 output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路径 output_node_names = "predictions" #原模型输出操做节点的名字 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #获得图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import. graph = tf.get_default_graph() #得到默认的图 input_graph_def = graph.as_graph_def() #返回一个序列化的图表明当前的图 with tf.Session() as sess: saver.restore(sess, input_checkpoint) #恢复图并获得数据 print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操做节点的名字 output_graph_def = graph_util.convert_variables_to_constants( #模型持久化,将变量值固定 sess, input_graph_def, output_node_names.split(",") #若是有多个输出节点,以逗号隔开 ) with tf.gfile.GFile(output_graph, "wb") as f: #保存模型 f.write(output_graph_def.SerializeToString()) #序列化输出 print("%d ops in the final graph." % len(output_graph_def.node)) #获得当前图有几个操做节点 for op in graph.get_operations(): print(op.name, op.values()) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析,help是提示符,type是输入的类型, # 这里运行程序时须要带上模型ckpt的路径,否则会报 error: too few arguments aggs = parser.parse_args() freeze_graph(aggs.model_folder) # freeze_graph("model/ckpt") #模型目录
部分参考:测试
TensorFlow实战Google深度学习框架、http://blog.csdn.net/lujiandong1/article/details/53385092编码