Tensorflow从1.3版本开始推出了官方支持的高层封装tf.estimator
。Estimators API提供了一整套训练模型、测试模型以及生成预测的方法。python
Tensorflow支持自定义estimator,首先须要定义一个模型函数model_fn,函数有4个输入:features,labels,mode和params。
features为模型的输入,labels为预测的真实值
mode的取值有3种:tf.estimator.ModeKeys.TRAIN
,tf.estimator.ModeKeys.EVAL
和tf.estimator.ModeKeys.PREDICT
,分别对应训练,验证和测试。经过mode的值,能够判断当前属于哪个阶段。params是一个字典,包含模型相关的超参数,例如learning rate等。
自定义函数model_fn返回值必须是一个tf.estimator.EstimatorSpec
对象,git
def __new__(cls, mode, predictions=None, loss=None, train_op=None, eval_metric_ops=None, export_outputs=None, training_chief_hooks=None, training_hooks=None, scaffold=None, evaluation_hooks=None, prediction_hooks=None):
其中,mode
表示模型的使用模式,对应model_fn的参数mode;predictions
表示根据输入的特征features
计算返回的预测值;loss
表示损失;train_op
表示对模型的损失进行最小化的op;eval_metric_ops
表示模型在eval时,须要额外输出的指标。export_outputs
表示导出模型的路径。还有一些钩子函数。
当mode不一样,EstimatorSpec所需的参数也不同。若是mode为TRAIN
,则实例化EstimatorSpec时,必须设置参数loss
和train_op
,当mode为EVAL
时,必须设置参数loss
,当mode为PREDICT
时,必须设置参数predictions
。github
def my_model(features, labels, mode, params): W = tf.Variable(tf.random_normal([1]), name="weight") b = tf.Variable(tf.zeros([1]), name="bias") predictions = tf.multiply(W, tf.cast(features, dtype=tf.float32)) + b if mode == tf.estimator.ModeKeys.PREDICT: return tf.estimator.EstimatorSpec(mode, predictions=predictions) loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions) mean_loss = tf.metrics.mean(loss) metrics = {'mean_loss':mean_loss} if mode == tf.estimator.ModeKeys.EVAL: # eval_metric_ops`用来定义评价指标,在运行eval的时候会计算这里定义的全部评测标准。 return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics) assert mode == tf.estimator.ModeKeys.TRAIN optimizer = tf.train.AdagradDAOptimizer(learning_rate=params["learning_rate"], global_step=tf.train.get_or_create_global_step()) train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_or_create_global_step()) return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
最后经过实例化tf.estimator.Estimator
就能够获得一个自定义的estimator。dom
def __init__(self, model_fn: Any, model_dir: Any = None, config: Any = None, params: Any = None, warm_start_from: Any = None) -> Any
参数model_fn
即为自定义的模型函数,model_dir
用于保存模型的参数和模型图等内容。warm_start_from
用来指定检查点路径,并导入checkpoint开始训练。warm_start_from能够经过tf.estimator.WarmStartSettings
实例化。函数
def __new__(cls, ckpt_to_initialize_from: Any, vars_to_warm_start: str = '.*', var_name_to_vocab_info: Any = None, var_name_to_prev_var_name: Any = None) -> _T
ckpt_to_initialize_from
能够指定加载checkpoint的路径,vars_to_warm_start
指定哪些参数须要热启动。学习
代码自定义estimator测试