Saving, Freezing, Optimizing for Inference, and Restoring TensorFlow Models
在训练完TensorFlow模型后,会得到三个文件:model-epoch_99.data-00000-of-00001,model-epoch_99.index,model-epoch_99.meta
1.tensorflowModel.ckpt.meta:TensorFlow将图结构与变量值分开存储。 文件.ckpt.meta包含完整的图结构。 它包括GraphDef,SaverDef等。
2.tensorflowModel.ckpt.data-00000-of-00001:它包含变量(权重,偏置,占位符,梯度,超参数等)的值。
3.tensorflowModel.ckpt.index:这是一个表,其中每一个键是张量tensor的名称,其值是序列化的BundleEntryProto。
import resnet_multitask


def classify_model(images, class_num):
    """Build the multitask ResNet-v2 network with ``is_training=False``.

    Args:
        images: Input tensor forwarded to ``resnet_multitask.resnet_v2``.
        class_num: Class count forwarded to ``resnet_multitask.resnet_v2``.

    Returns:
        The ``(logits, pre_heatmap, end_points)`` tuple produced by
        ``resnet_multitask.resnet_v2``.
    """
    inference_scope = resnet_multitask.resnet_arg_scope(is_training=False)
    with slim.arg_scope(inference_scope):
        outputs = resnet_multitask.resnet_v2(images, class_num)
    logits, pre_heatmap, end_points = outputs
    return logits, pre_heatmap, end_points


# NOTE(review): `tf`, `slim`, `w`, `h`, `c` and `class_num` are not defined in
# this snippet; presumably the surrounding script provides them -- confirm.
restore_path = './checkpoint/model-epoch_99'

with tf.Session() as sess:
    # The placeholder name 'input_x' is what later code looks up as "input_x:0".
    input_x = tf.placeholder(tf.float32, shape=[None, w, h, c], name='input_x')
    logits, pre_heatmap, end_points = classify_model(input_x, class_num)

    # Load the trained variable values from the checkpoint into this graph.
    saver = tf.train.Saver()
    saver.restore(sess, restore_path)

    # Write the graph structure out as a text-format GraphDef (.pbtxt).
    tf.train.write_graph(
        sess.graph.as_graph_def(),
        '.',
        './checkpoint/tensorflowModel.pbtxt',
        as_text=True,
    )
from tensorflow.python.tools import freeze_graph

# Merge the text GraphDef with the checkpoint values and write a single frozen
# binary GraphDef. Keyword arguments are used because the call mixes ten
# positional strings and booleans that are easy to misorder.
freeze_graph.freeze_graph(
    input_graph='./checkpoint/tensorflowModel.pbtxt',
    input_saver="",
    input_binary=False,  # the input graph is a .pbtxt, not a binary .pb
    # Same checkpoint prefix that was restored when the .pbtxt was generated.
    input_checkpoint='./checkpoint/model-epoch_99',
    output_node_names="resnet_v2/predictions/Reshape_1",
    restore_op_name="save/restore_all",
    filename_tensor_name="save/Const:0",
    output_graph='./checkpoint/model.pb',
    clear_devices=True,
    initializer_nodes="",
)
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
        input_graph: A `GraphDef` file to load.
        input_saver: A TensorFlow Saver file.
        input_binary: A Bool. True means input_graph is .pb, False indicates
            .pbtxt.
        input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
            priority. Typically the result of `Saver.save()` or that of
            `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded
            or V1/V2.
        output_node_names: The name(s) of the output nodes, comma separated.
        restore_op_name: Unused.
        filename_tensor_name: Unused.
        output_graph: String where to write the frozen `GraphDef`.
        clear_devices: A Bool whether to remove device specifications.
        initializer_nodes: Comma separated list of initializer nodes to run
            before freezing.
        variable_names_whitelist: The set of variable names to convert
            (optional; by default, all variables are converted).
        variable_names_blacklist: The set of variable names to omit converting
            to constants (optional).
        input_meta_graph: A `MetaGraphDef` file to load (optional).
        input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel'
            file and variables (optional).
        saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef
            to load, in string format.
        checkpoint_version: Tensorflow variable file format
            (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2).

    Returns:
        String that is the location of frozen GraphDef.
    """
import tensorflow as tf
from tensorflow.python.platform import gfile

model = './checkpoint/model.pb'

# Parse the frozen GraphDef. The context manager closes the file handle, which
# the previous bare `FastGFile(...).read()` never did; `GFile` replaces the
# deprecated `FastGFile`.
graph_def = tf.GraphDef()
with gfile.GFile(model, 'rb') as pb_file:
    graph_def.ParseFromString(pb_file.read())

# Import the nodes into the default graph under the 'graph/' name prefix.
graph = tf.get_default_graph()
tf.import_graph_def(graph_def, name='graph')

# Write the graph as a TensorBoard event file. close() flushes the pending
# event to disk; without it the log may be incomplete when the script exits.
summaryWriter = tf.summary.FileWriter('./logs/model', graph)
summaryWriter.close()
写一个start_tensorboard.bat,内容如下,然后运行,打开浏览器,在地址栏输入http://localhost:6006
rem Change to the Anaconda Scripts directory; /d also switches the drive
rem when the script is started from a drive other than C:.
cd /d C:\software\Anaconda3\Scripts
rem Start TensorBoard on the exported graph log directory.
tensorboard.exe --logdir=C:\workspace\code\img_classify\logs\model
import time

import cv2
import numpy as np
import tensorflow as tf


def recognize(jpg_path, pb_file_path):
    """Classify one image with a frozen TensorFlow graph and print the result.

    Prints the time taken by the ``sess.run`` call, the raw softmax output
    and the predicted label.

    Args:
        jpg_path: Path of the image file to classify.
        pb_file_path: Path of the frozen binary ``GraphDef`` (``.pb``) file.

    Returns:
        The index of the largest softmax value (``np.argmax`` over the whole
        output array).

    Raises:
        ValueError: If ``cv2.imread`` cannot load ``jpg_path``.
    """
    # cv2.imread returns None instead of raising when the file is missing or
    # cannot be decoded; fail here with a clear message rather than inside
    # cvtColor.
    img = cv2.imread(jpg_path)
    if img is None:
        raise ValueError("cannot read image: %r" % (jpg_path,))

    # OpenCV loads BGR; convert to RGB, resize to 224x224, cast to float32,
    # add a leading batch axis and divide by 255.
    img_ori = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    test_img = cv2.resize(img_ori, (224, 224))
    test_img = np.asarray(test_img, np.float32)
    test_img = test_img[np.newaxis, :] / 255.

    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # name="" keeps the original node names (no import prefix).
        tf.import_graph_def(output_graph_def, name="")

        with tf.Session() as sess:
            # No variable initializer is run: import_graph_def does not add
            # anything to the global-variables collection, so the former
            # tf.global_variables_initializer() call was a no-op.
            input_x = sess.graph.get_tensor_by_name("input_x:0")
            out_softmax = sess.graph.get_tensor_by_name(
                "resnet_v2/predictions/Reshape_1:0")

            time_start = time.time()
            img_out_softmax = sess.run(out_softmax,
                                       feed_dict={input_x: test_img})
            time_end = time.time()
            print('run time: ', time_end - time_start, 's')

            print("img_out_softmax:", img_out_softmax)
            prediction_labels = np.argmax(img_out_softmax)
            print("label:", prediction_labels)
            return prediction_labels


recognize(r'C:\data\test_image.jpg', "./checkpoint/model.pb")