【tensorflow】生成.pb文件

Saving, Freezing, Optimizing for inference, Restoring of TensorFlow models

在训练完 TensorFlow 模型后，会得到三个文件：model-epoch_99.data-00000-of-00001、model-epoch_99.index、model-epoch_99.meta

1.tensorflowModel.ckpt.meta：TensorFlow 将图结构与变量值分开存储。文件 .ckpt.meta 包含完整的图结构，包括 GraphDef、SaverDef 等。
2.tensorflowModel.ckpt.data-00000-of-00001：它包含各变量（权重、偏置、优化器状态、以变量形式保存的超参数等）的值。注意占位符（placeholder）本身没有值，不会被保存。
3.tensorflowModel.ckpt.index：这是一个表，其中每个键是张量（tensor）的名称，其值是序列化的 BundleEntryProto。

  • 第一步先生成 tensorflowModel.pbtxt 文件：可以在测试程序中执行完 saver.restore 之后，将 graph 保存为 .pbtxt。
     
import tensorflow as tf
import tensorflow.contrib.slim as slim

import resnet_multitask

def classify_model(images, class_num):
    """Build the multitask ResNet-v2 graph with the inference-mode arg scope.

    Args:
        images: Input image tensor passed straight to
            ``resnet_multitask.resnet_v2``.
        class_num: Class count forwarded to ``resnet_multitask.resnet_v2``.

    Returns:
        Tuple ``(logits, pre_heatmap, end_points)`` unpacked from the
        result of ``resnet_multitask.resnet_v2``.
    """
    # is_training=False is requested from the project's arg scope; presumably
    # this switches layers such as batch norm to inference behaviour --
    # confirm in resnet_multitask.resnet_arg_scope.
    scope = resnet_multitask.resnet_arg_scope(is_training=False)
    with slim.arg_scope(scope):
        net_logits, heatmap, points = resnet_multitask.resnet_v2(images, class_num)
        return net_logits, heatmap, points

restore_path = './checkpoint/model-epoch_99'

# NOTE(review): w, h, c and class_num are not defined in this snippet; they
# must match the values used during training. The inference code later feeds
# 224x224 RGB images, so presumably w = h = 224 and c = 3 -- confirm.
with tf.Session() as sess:
    # Rebuild the inference graph with a named input so the frozen model can
    # later be fed through the tensor "input_x:0".
    input_x = tf.placeholder(tf.float32, shape=[None, w, h, c], name='input_x')
    logits, pre_heatmap, end_points = classify_model(input_x, class_num)

    # Load the trained checkpoint into the freshly built graph.
    saver = tf.train.Saver()
    saver.restore(sess, restore_path)

    # Dump the graph structure as a text GraphDef (.pbtxt) for freeze_graph.
    graph_def = sess.graph.as_graph_def()
    tf.train.write_graph(graph_def, '.', './checkpoint/tensorflowModel.pbtxt', as_text=True)
from tensorflow.python.tools import freeze_graph

# Merge the text GraphDef with the trained weights into one frozen .pb file.
# input_checkpoint must be the same checkpoint prefix that was passed to
# saver.restore (files model-epoch_99.index / .data-00000-of-00001 / .meta);
# a prefix that does not exist on disk makes freeze_graph fail.
freeze_graph.freeze_graph(
    input_graph='./checkpoint/tensorflowModel.pbtxt',
    input_saver="",
    input_binary=False,  # the graph was written with as_text=True (.pbtxt)
    input_checkpoint='./checkpoint/model-epoch_99',
    output_node_names="resnet_v2/predictions/Reshape_1",
    restore_op_name="save/restore_all",  # documented as unused by freeze_graph
    filename_tensor_name="save/Const:0",  # documented as unused by freeze_graph
    output_graph='./checkpoint/model.pb',
    clear_devices=True,
    initializer_nodes="",
)
# NOTE(review): reference excerpt -- this is the signature and docstring of
# tensorflow.python.tools.freeze_graph.freeze_graph, copied here to document
# its parameters; the implementation body is omitted. Executed as-is it raises
# NameError because tag_constants / saver_pb2 are not imported in this
# snippet, and if they were, it would shadow the imported `freeze_graph` module.
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.
  Args:
    input_graph: A `GraphDef` file to load.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted),
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).
  Returns:
    String that is the location of frozen GraphDef.
  """
  • 生成 .pb 文件后，可以通过 TensorBoard 进行可视化：
import tensorflow as tf
from tensorflow.python.platform import gfile


model = './checkpoint/model.pb'
graph = tf.get_default_graph()

# Parse the frozen model into a fresh GraphDef; the context manager makes sure
# the file handle is closed after reading.
graph_def = tf.GraphDef()
with gfile.GFile(model, 'rb') as f:
    graph_def.ParseFromString(f.read())

# name='graph' prefixes every imported node name with "graph/".
tf.import_graph_def(graph_def, name='graph')

# FileWriter writes events asynchronously; close() flushes them to disk so the
# graph is actually present in ./logs/model when TensorBoard reads it.
summaryWriter = tf.summary.FileWriter('./logs/model', graph)
summaryWriter.close()

写一个 start_tensorboard.bat，内容如下，然后运行它；再打开浏览器，在地址栏输入 http://localhost:6006

REM Start TensorBoard on the exported graph, then browse to http://localhost:6006
REM "/d" also switches the current drive, so this works when the .bat is
REM launched from a different drive than C:.
cd /d C:\software\Anaconda3\Scripts
tensorboard.exe --logdir=C:\workspace\code\img_classify\logs\model

  • 载入.pb模型进行前向运算
import tensorflow as tf
import  numpy as np
import time
import cv2

def _load_image(jpg_path, input_size):
    """Read an image file and return it as a batch of one float32 RGB image.

    Args:
        jpg_path: Path of the image file.
        input_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        Array of shape ``(1, height, width, 3)`` with values scaled to [0, 1].

    Raises:
        IOError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(jpg_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than with a cryptic cvtColor error.
    if img is None:
        raise IOError("cannot read image: %s" % jpg_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
    resized = cv2.resize(img_rgb, input_size)
    return np.asarray(resized, np.float32)[np.newaxis, :] / 255.


def recognize(jpg_path, pb_file_path, input_size=(224, 224),
              input_tensor_name="input_x:0",
              output_tensor_name="resnet_v2/predictions/Reshape_1:0"):
    """Classify one image with a frozen ``.pb`` model and print the result.

    Args:
        jpg_path: Path of the image to classify.
        pb_file_path: Path of the frozen GraphDef written by ``freeze_graph``.
        input_size: ``(width, height)`` the image is resized to; defaults to
            224x224 as in the original script.
        input_tensor_name: Name of the input placeholder tensor.
        output_tensor_name: Name of the softmax output tensor.

    Returns:
        The predicted class index (``np.argmax`` of the model output).

    Raises:
        IOError: If ``jpg_path`` cannot be read as an image.
    """
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # Empty name prefix keeps the original tensor names (e.g. "input_x:0").
        tf.import_graph_def(output_graph_def, name="")

        # freeze_graph converted every variable into a constant, so no
        # variable initializer needs to be run before inference.
        with tf.Session() as sess:
            input_x = sess.graph.get_tensor_by_name(input_tensor_name)
            out_softmax = sess.graph.get_tensor_by_name(output_tensor_name)

            test_img = _load_image(jpg_path, input_size)

            # NOTE: the first sess.run may include one-off setup cost, so this
            # single measurement can overstate steady-state latency.
            time_start = time.time()
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x: test_img})
            time_end = time.time()
            print('run time: ', time_end - time_start, 's')

            print("img_out_softmax:", img_out_softmax)
            prediction_labels = np.argmax(img_out_softmax)
            print("label:", prediction_labels)
            return prediction_labels

if __name__ == "__main__":
    # Example run; only executes when this file is run as a script, not when
    # it is imported. Adjust the image path for your machine.
    recognize(r'C:\data\test_image.jpg', "./checkpoint/model.pb")