TensorFlow

官网

Tensorflow源码分析

A、基本概念

  1. Graph

  2. Tensor 

  3. Session

B、Tools

  1. Checkpoint  .Ckpt

  2. Pb

  3. .Ckpt To .Pb

  4. TensorBoard

 B.1  .Ckpt 模型加载

1. 模型的保存

import tensorflow as tf

def store_model_ckpt(ckpt_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    #模型的保存必须有变量
    c = tf.Variable(1, name='c')
    a = tf.add(x, y, name='op')
    result = tf.add(a, c)

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
    
        saver = tf.train.Saver()
    
        #若是只保存其中一部分变量,则使用下面代码,用列表或者字典均可以
        #saver = tf.train.Saver([x, y])
    
        #这里面有参数global_step=50,当训练50步便保存模型
        saver.save(sess, ckpt_file_path)
        # test
        feed_dict = {x: 2, y: 3}
        print(sess.run(result, feed_dict))

def main():
    ckpt_file_path = "./ckpt/model.ckpt"
    store_model_ckpt(ckpt_file_path)

if __name__ == '__main__':
    main()

结果:6node

程序生成并保存四个文件python

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

2. 模型恢复加载

针对上面的模型保存例子,还原模型的过程以下:git

import tensorflow as tf

def restore_model_ckpt():
    with tf.Session() as sess:
        #step1:加载模型结构
        saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')
        #step2:只须要指定目录就能够恢复全部变量信息
        saver.restore(sess,tf.train.latest_checkpoint('./ckpt'))
        
        #直接获取保存的变量
        print(sess.run('c:0'))
        
        #获取placeholder变量,经过get_tensor_by_name
        x = sess.graph.get_tensor_by_name('x:0')
        y = sess.graph.get_tensor_by_name('y:0')
        
        #获取须要进行计算的op算子,此op为加法
        op = sess.graph.get_tensor_by_name('op:0')
        
        #加入新的op操做,新的op为乘法
        new_op = tf.multiply(op, 2)
        
        #test
        feed_dict = {x:2, y:3}
        
        result = sess.run(new_op,feed_dict)
        print(result)

def main():
    restore_model_ckpt()
    
if __name__ == '__main__':
    main()

结果:10浏览器

  1. 首先还原模型结构网络

  2. 而后还原变量(参数)信息架构

  3. 最后咱们就能够得到已训练的模型中的各类信息了(保存的变量、placeholder变量、operator等),同时能够对获取的变量添加各类新的操做(见以上代码注释)。
  而且,咱们也能够加载部分模型,在此基础上加入其它操做,具体能够参考官方文档和demo。dom

  针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,能够参考。函数

  同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也很是好,能够参考。源码分析

B. 2 Pb模型文件

 1. pb模型保存

import tensorflow as tf
from tensorflow.python.framework import graph_util

def store_model_pb(pb_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    a = tf.add(x, y)
    #该op算子应该加上name
    op = tf.add(a, b, name='op')
    
    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        sess.run(init)
        
        #导出当前计算图的GraphDef部分,只须要这一部分就能够完成从输入层到输出层的计算
        graph_def = tf.get_default_graph().as_graph_def()
        
        #将图中的变量及其取值转化为常量,同时将图中的没必要要的节点去掉
        output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['op'])
        
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())
        
        #test
        feed_dict = {x: 2, y: 3}
        print(sess.run(op, feed_dict))

def main():
    pb_file_path = "model.pb"
    store_model_pb(pb_file_path)
    
if __name__ == '__main__':
    main()
    

结果:6 测试

  在当前文件下面生成model.pb文件

2. pb模型加载

import tensorflow as tf
from tensorflow.python.platform import gfile
    
def restore_model_pb(pb_file_path):
    with tf.Session() as sess:
        with gfile.FastGFile(pb_file_path, 'rb') as f:
            graph_def = tf.GraphDef()
            #转换成字符串形式
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
       
        #获取placeholder的变量
        x = sess.graph.get_tensor_by_name('x:0')
        y = sess.graph.get_tensor_by_name('y:0')
        
        #获取op算子
        op = sess.graph.get_tensor_by_name('op:0')
        
        feed_dict = {x: 2, y:3}
        result = sess.run(op,feed_dict)
        print(result)
          
def main():
    pb_file_path = "model.pb"
    restore_model_pb(pb_file_path)
    
if __name__ == '__main__':
    main()

结果:5

B 3. 将.Ckpt 转换为.Pb

  但不少时候,咱们须要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其余地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认状况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,所以须要采用别的方法。 咱们知道,graph_def文件中没有包含网络中的Variable值(一般状况存储了权重),可是却包含了constant值,因此若是咱们能把Variable转换为constant,便可达到使用一个文件同时存储网络架构与权重的目标。

    TensoFlow为咱们提供了convert_variables_to_constants()方法,该方法能够固化模型结构,将计算图中的变量取值以常量的形式保存,并且保存的模型能够移植到Android平台。

1、CKPT 转换成 PB格式

  将CKPT 转换成 PB格式的文件的过程可简述以下:

    1. 经过传入 CKPT 模型的路径获得模型的图和变量数据
    2. 经过 import_meta_graph 导入模型中的图
    3. 经过 saver.restore 从模型中恢复图中各个变量的数据
    4. 经过 graph_util.convert_variables_to_constants 将模型持久化

Code:freeze_graph.py

import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(ckpt_file_path, pb_file_path):
    #“input:0”是张量的名称,而"input"表示的是节点的名称。
    #此处输入的应该是节点的名称
    output_node_names = "op"
    #首先恢复图结构
    saver = tf.train.import_meta_graph(ckpt_file_path+'.meta',clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    
    with tf.Session() as sess:
        #恢复图并获得数据
        saver.restore(sess,ckpt_file_path)
        output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                #若是有多个输出节点
                output_node_names=output_node_names.split(","))
        with tf.gfile.GFile(pb_file_path,"wb") as f:
            f.write(output_graph_def.SerializeToString())
            print("%d ops in the final graph." % len(output_graph_def.node)) 
                     
def main():
    # 输入ckpt模型路径
    model_folder = "D:\AI\Ckpt\TestCkpt\ckpt"
    #检查目录下ckpt文件状态是否可用
    checkpoint = tf.train.get_checkpoint_state(model_folder) 
    #得ckpt文件路径
    ckpt_file_path = checkpoint.model_checkpoint_path 
    
    # 输出pb模型的路径
    pb_file_path="frozen_model.pb"
    
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(ckpt_file_path,pb_file_path)
    
if __name__ == '__main__':
    main()

结果:生成 frozen_model.pb文件,能够采用上面pb模型加载的方法测试该pb文件

说明:

一、函数freeze_graph中,最重要的就是要肯定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操做,咱们须要定义输出结点的名字。由于网络实际上是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所须要的子图都固化下来,其余无关的就舍弃掉。由于咱们freeze模型的目的是接下来作预测。因此,output_node_names通常是网络模型最后一层输出的节点名称,或者说就是咱们预测的目标。

 二、在保存的时候,经过convert_variables_to_constants函数来指定须要固化的节点名称,对于鄙人的代码,须要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。

三、源码中经过graph = tf.get_default_graph()得到默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,所以必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。

四、上面以及说明:在保存的时候,经过convert_variables_to_constants函数来指定须要固化的节点名称,对于鄙人的代码,须要固化的节点只有一个:output_node_names。所以,其余网络模型,也能够经过简单的修改输出的节点名称output_node_names,将ckpt转为pb文件 。

       PS:注意节点名称,应包含name_scope 和 variable_scope命名空间,并用“/”隔开,如"InceptionV3/Logits/SpatialSqueeze"

B.4 TensorBoard

  1. 生成graph

# -*- coding: utf-8 -*-
"""
Created on Sat Dec 22 09:49:04 2018

@author: weilong
"""

import tensorflow as tf

#定义简单的计算图,实现向量加法的操做
with tf.name_scope("imput1"):
    input1 = tf.constant([1.0, 2.0, 3.0], name="input1")
with tf.name_scope("input2"):
    input2 = tf.Variable(tf.random_uniform([3]), name="input2")
output = tf.add_n([input1, input2], name="add")

#生成写日志的writer,并将当前的tensorflow计算图写入日志
writer = tf.summary.FileWriter("./log", tf.get_default_graph())
writer.close()

 2. 将训练好的model.pb文件在tensorboard中展现其网络结构

import tensorflow as tf

model = 'model.pb' #请将这里的pb文件路径改成本身的
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)

执行以上代码就会生成文件在log/events.out.tfevents.1535079670.DESKTOP-5IRM000。

 在tensorboard中加载:

tensorboard --logdir=\path\to\log

在浏览器中

拷贝网站连接在浏览器中便可。

参考:https://blog.csdn.net/guyuealian/article/details/82218092

相关文章
相关标签/搜索