1. High level API checkpointspython
只针对与 estimatorlua
设置检查点的时间频率和总个数spa
my_checkpointing_config = tf.estimator.RunConfig( save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes. keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints. )
实例化时传递给 estimator 的 config 参数rest
model_dir 设置存储路径code
classifier = tf.estimator.DNNClassifier( feature_columns=my_feature_columns, hidden_units=[10, 10], n_classes=3, model_dir='models/iris', config=my_checkpointing_config)
一旦检查点文件存在,TensorFlow 总会在你调用 train()
、 evaluation()
或 predict()
时重建模型教程
------------------------------------------------------------------------------------------------------------get
2.Low level API tf.train.Saverit
-------------------------------------------------------------------------------------------------------------io
Saver.save 存储 model 中的全部变量class
import tensorflow as tf # 建立变量 var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer) # 添加初始化变量的操做 init_op = tf.global_variables_initializer() # 添加保存和恢复这些变量的操做 saver = tf.train.Saver() # 而后,加载模型,初始化变量,完成一些工做,并保存这些变量到磁盘中 with tf.Session() as sess: sess.run(init_op) # 使用模型完成一些工做 var.op.run() # 将变量保存到磁盘中 save_path = saver.save(sess, "/tmp/model.ckpt") print("Model saved in path: %s" % save_path)
var = tf.get_variable("var", shape=[3], initializer = tf.zeros_initializer) # tf.get_variable: Gets an existing variable with these parameters or create a new one. # shape: Shape of the new or existing variable # initializer: Initializer for the variable if one is created. tf.zeros_initializer 赋值为0 [0 0 0]
saver = tf.train.Saver() # Saver 来管理模型中的全部变量,注意是全部变量
tf.Session() # A class for running TensorFlow operations.
with...as... #执行 with 后面的语句,若是能够执行则将赋值给 as 后的语句。若是出现错误则执行 with 后语句中的 __exit__ #来报错。相似与 try if,可是更方便
Saver.save 选择性的存储变量
saver = tf.train.Saver({'var2':var2})
-------------------------------------------------------------------------------------------------------------
Saver.restore 加载路径中的全部变量
import tensorflow as tf tf.reset_default_graph() # 建立一些变量 var = tf.get_variable("var", shape=[3]) # 添加保存和恢复这些变量的操做 saver = tf.train.Saver() # 而后,加载模型,使用 saver 从磁盘中恢复变量,并使用变量完成一些工做 with tf.Session() as sess: # 从磁盘中恢复变量 saver.restore(sess, "/tmp/model.ckpt") print("Model restored.") # 检查变量的值 print("var : %s" % var.eval())
-------------------------------------------------------------------------------------------------------------
inspector_checkpoint 检查存储的变量
加载 inspect_checkpoints
from tensorflow.python.tools import inspect_checkpoint as chkp
打印存储起来的全部变量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True, all_tensor_names=False)
注意其中的参数 all_tensor_names 教程中并未添加这个参数,运行时持续报错 missing
打印制定的变量
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='var1', all_tensors=False, all_tensor_names=False)