原文地址:https://blog.csdn.net/mrr1ght/article/details/81011280 。本文有删减。html
tf.train.SessionRunHook()是一个类;用来定义Hooks;python
Hooks是什么,官方文档中关于training hooks的定义是:session
Hooks are tools that run in the process of training/evaluation of the model.函数
Hooks是在模型训练/测试过程当中的工具。Pytorch中也常常会有这个概念出现,其实也就跟keras里的callbacks同样,hook和callback都是在训练过程当中执行特定的任务。工具
例如判断是否须要中止训练的EarlyStopping;改变学习率的LearningRateScheduler,他们都有一个共性,就是在每一个step开始/结束或者每一个epoch开始/结束时须要执行某个操做。如每一个epoch结束都保存一次checkpoint;每一个epoch结束时都判断一次loss有没有降低,若是loss没有降低的轮数大于提取设定的阈值,就终止训练。固然以上的功能咱们均可以本身彻底重头实现。可是这些keras和tersorflow提供了更好的工具就是hook和callback,而且一些经常使用的功能都已经实现好了。说到底每一个hook和callback都是按照固定格式定义了在每一个step开始/结束要执行的操做,每一个epoch开始/结束执行的操做。学习
Hooks都是继承自父类tf.train.SessionRunHook()
,首先看一下这个父类的定义源码;测试
tf.train.SessionRunHook()
类定义在tensorflow/python/training/session_run_hook.py
,类中每一个函数的做用与何时调用都已加入函数注释中;lua
class SessionRunHook(object): """Hook to extend calls to MonitoredSession.run().""" def begin(self): """再建立会话以前调用 调用begin()时,default graph会被建立, 可在此处向default graph增长新op,begin()调用后,default graph不能再被修改 """ pass def after_create_session(self, session, coord): # pylint: disable=unused-argument """tf.Session被建立后调用 调用后会指示全部的Hooks有一个新的会话被建立 Args: session: A TensorFlow Session that has been created. coord: A Coordinator object which keeps track of all threads. """ pass def before_run(self, run_context): # pylint: disable=unused-argument """调用在每一个sess.run()执行以前 能够返回一个tf.train.SessRunArgs(op/tensor),在即将运行的会话中加入这些op/tensor; 加入的op/tensor会和sess.run()中已定义的op/tensor合并,而后一块儿执行; Args: run_context: A `SessionRunContext` object. Returns: None or a `SessionRunArgs` object. """ return None def after_run(self, run_context, # pylint: disable=unused-argument run_values): # pylint: disable=unused-argument """调用在每一个sess.run()以后 参数run_values是befor_run()中要求的op/tensor的返回值; 能够调用run_context.qeruest_stop()用于中止迭代 sess.run抛出任何异常after_run不会被调用 Args: run_context: A `SessionRunContext` object. run_values: A SessionRunValues object. """ pass def end(self, session): # pylint: disable=unused-argument """在会话结束时调用 end()常被用于Hook想要执行最后的操做,如保存最后一个checkpoint 若是sess.run()抛出除了表明迭代结束的OutOfRange/StopIteration异常外, end()不会被调用 Args: session: A TensorFlow Session that will be soon closed. """ pass
tf.train.SessionRunHook()
类中定义的方法的参数run_context
,run_values
,run_args
,包含sess.run()
会话运行所需的一切信息,spa
run_context
:类tf.train.SessRunContext
的实例run_values
:类tf.train.SessRunValues
的实例run_args
:类tf.train.SessRunArgs
的实例.这三个类会在下面详细介绍.net
(1)可使用tf中已经预约义好的Hook,其都是tf.train.SessionRunHook()的子类;如
(2)也可用tf.train.SessionRunHook()定义本身的Hook,并重写类中的方法;而后把想要使用的Hook(预约义好的或者本身定义的)放到tf.train.MonitorTrainingSession()参数[Hook]列表中;
关于tf.train.MonitorTrainingSession()
参见tf.train.MonitoredTrainingSession()解析。
给一个定义本身Hook的栗子,来自cifar10
class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time#duration持续的时间 self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch))
这三个类都服务于sess.run(),区别以下:
(1) tf.train.SessRunArgs类
提供给会话运行的参数,与sess.run()参数定义同样:
fethes,feeds,option
(2) tf.train.SessRunValues
用于保存sess.run()的结果,其中resluts是sess.run()返回值中对应于SessRunArgs()的返回值,
(3) tf.train.SessRunContext
SessRunContext包含sess.run()所需的一切信息
属性:
方法:
equest_stop(): 设置_stop_request值为True
tf.train.SessionRunHook()和tf.train.MonitorTrainingSession()通常一块儿使用,下面是cifar10中的使用实例
class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time#duration持续的时间 self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) #monitored 被监控的 with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()], config=tf.ConfigProto( log_device_placement=FLAGS.log_device_placement)) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(train_op)