首先定义一个tf.train.Saver类:spa
saver = tf.train.Saver(max_to_keep=1)
其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后5个模型,若是设置成0,训练过程当中的全部模型都会被保存。rest
模型训练好之后,保存模型:code
saver.save(sess, ckpt_dir + "/nn_model.ckpt", global_step=1)
其中,sess是Session,ckpt_dir + "/nn_model.ckpt"是保存的路径和名称,global_step是模型名称的后缀名,因为咱们只保存最后一个模型,因此能够设置为1,若是每个模型都想保存,能够设置成训练的epoch。blog
载入模型比较简单:io
saver.restore(sess, model_file)
其中,sess是Session,model_file是模型的路径和名称。class