Tensorflow 模型保存与调用

说明:训练模型,保存相关参数,以便在之后验证时直接输入验证数据集便可获得模型模拟结果。html

主要参考了官方教程和博客 http://www.javashuo.com/article/p-mibtkwkd-ed.htmlrest

 

1、 模型存储htm

mymodel.meta -----------保存完整Tensorflow graph的protocol buffer,好比说,全部的 variables, operations, collections等等blog

mymodel.data-00000-of-00001 ----------.data文件中包含了训练变量,如权重(weights),偏置(biases),梯度(gradients)和全部其余保存的变量(variables)。教程

mymodel.indexget

checkpoint -----------记录最新保存的模型的存储路径。博客

 

二、保存模型it

使用tf.train.Saver() 类io

例:saver=tf.train.Saver(tf.global_variables(),max_to_keep=20)import

若是在tf.train.Saver()中没有指定任何东西,将保存全部变量。

若是不想保存全部的变量,只想保存其中一些变量,能够在建立tf.train.Saver实例的时候,给它传递一个想要保存的变量的list或者字典。

 

三、调用一个已经训练好的模型

使用tf.train.import_meta_graph()、saver.restore() 和 tf.get_default_graph()

例:with tf.Session() as sess:

              saver=tf.train.import_meta_graph('train.model-1000.meta')     #指定参数的读取路径
              saver.restore(sess,('train.model-1000'))                                   #提取参数
              graph = tf.get_default_graph()                                                  #获取模型结构(张量图graph)

             #经过变量名加载变量的值

             X=graph.get_tensor_by_name('X:0')

            #注意:若想经过变量名称加载变量,要求已保存的模型中为变量指明了变量名

 

四、模型再训练

 在三、中把模型的结构和参数提取出来后,直接按本身的需求编写模型训练的代码便可。

相关文章
相关标签/搜索