如何使用TensorFlow中的高级API:Estimator、Experiment和Dataset

近日,背景调查公司 Onfido 研究主管 Peter Roelants 在 Medium 上发表了一篇题为《Higher-Level APIs in TensorFlow》的文章,经过实例详细介绍了如何使用 TensorFlow 中的高级 API(Estimator、Experiment 和 Dataset)训练模型。值得一提的是 Experiment 和 Dataset 能够独立使用。这些高级 API 已被最新发布的 TensorFlow1.3 版收录。
TensorFlow 中有许多流行的库,如 Keras、TFLearn 和 Sonnet,它们可让你轻松训练模型,而无需接触哪些低级别函数。目前,Keras API 正倾向于直接在 TensorFlow 中实现,TensorFlow 也在提供愈来愈多的高级构造,其中的一些已经被最新发布的 TensorFlow1.3 版收录。
在本文中,咱们将经过一个例子来学习如何使用一些高级构造,其中包括 Estimator、Experiment 和 Dataset。阅读本文须要预先了解有关 TensorFlow 的基本知识。
Experiment、Estimator 和 DataSet 框架和它们的相互做用(如下将对这些组件进行说明)
在本文中,咱们使用 MNIST 做为数据集。它是一个易于使用的数据集,能够经过 TensorFlow 访问。你能够在这个 gist 中找到完整的示例代码。使用这些框架的一个好处是咱们不须要直接处理图形和会话
Estimator
Estimator(评估器)类表明一个模型,以及这些模型被训练和评估的方式。咱们能够这样构建一个评估器:
为了构建一个 Estimator,咱们须要传递一个模型函数,一个参数集合以及一些配置。
  • 参数应该是模型超参数的集合,它能够是一个字典,但咱们将在本示例中将其表示为 HParams 对象,用做 namedtuple。
  • 该配置指定如何运行训练和评估,以及如何存出结果。这些配置经过 RunConfig 对象表示,该对象传达 Estimator 须要了解的关于运行模型的环境的全部内容。
  • 模型函数是一个 Python 函数,它构建了给定输入的模型(见后文)。
模型函数
模型函数是一个 Python 函数,它做为第一级函数传递给 Estimator。稍后咱们就会看到,TensorFlow 也会在其余地方使用第一级函数。模型表示为函数的好处在于模型能够经过实例化函数不断从新构建。该模型能够在训练过程当中被不一样的输入不断建立,例如:在训练期间运行验证测试。
模型函数将输入特征做为参数,相应标签做为张量。它还有一种模式来标记模型是否正在训练、评估或执行推理。模型函数的最后一个参数是超参数的集合,它们与传递给 Estimator 的内容相同。模型函数须要返回一个 EstimatorSpec 对象——它会定义完整的模型。
EstimatorSpec 接受预测,损失,训练和评估几种操做,所以它定义了用于训练,评估和推理的完整模型图。因为 EstimatorSpec 采用常规 TensorFlow Operations,所以咱们可使用像 TF-Slim 这样的框架来定义本身的模型。
Experiment
Experiment(实验)类是定义如何训练模型,并将其与 Estimator 进行集成的方式。咱们能够这样建立一个实验类:
Experiment 做为输入:
  • 一个 Estimator(例如上面定义的那个)。
  • 训练和评估数据做为第一级函数。这里用到了和前述模型函数相同的概念,经过传递函数而非操做,若有须要,输入图能够被重建。咱们会在后面继续讨论这个概念。
  • 训练和评估钩子(hooks)。这些钩子能够用于监视或保存特定内容,或在图形和会话中进行一些操做。例如,咱们将经过操做来帮助初始化数据加载器。
  • 不一样参数解释了训练时间和评估时间。
  • 一旦咱们定义了 experiment,咱们就能够经过 learn_runner.run 运行它来训练和评估模型:
与模型函数和数据函数同样,函数中的学习运算符将建立 experiment 做为参数。
Dataset
咱们将使用 Dataset 类和相应的 Iterator 来表示咱们的训练和评估数据,并建立在训练期间迭代数据的数据馈送器。在本示例中,咱们将使用 TensorFlow 中可用的 MNIST 数据,并在其周围构建一个 Dataset 包装器。例如,咱们把训练的输入数据表示为:

调用这个 get_train_inputs 会返回一个一级函数,它在 TensorFlow 图中建立数据加载操做,以及一个 Hook 初始化迭代器。
本示例中,咱们使用的 MNIST 数据最初表示为 Numpy 数组。咱们建立一个占位符张量来获取数据,再使用占位符来避免数据被复制。接下来,咱们在 from_tensor_slices 的帮助下建立一个切片数据集。咱们将确保该数据集运行无限长时间(experiment 能够考虑 epoch 的数量),让数据获得清晰,并分红所需的尺寸。
为了迭代数据,咱们须要在数据集的基础上建立迭代器。由于咱们正在使用占位符,因此咱们须要在 NumPy 数据的相关会话中初始化占位符。咱们能够经过建立一个可初始化的迭代器来实现。建立图形时,咱们将建立一个自定义的 IteratorInitializerHook 对象来初始化迭代器:
IteratorInitializerHook 继承自 SessionRunHook。一旦建立了相关会话,这个钩子就会调用 call after_create_session,并用正确的数据初始化占位符。这个钩子会经过 get_train_inputs 函数返回,并在建立时传递给 Experiment 对象。
train_inputs 函数返回的数据加载操做是 TensorFlow 操做,每次评估时都会返回一个新的批处理。
运行代码
如今咱们已经定义了全部的东西,咱们能够用如下命令运行代码:
若是你不传递参数,它将使用文件顶部的默认标志来肯定保存数据和模型的位置。训练将在终端输出全局步长、损失、精度等信息。除此以外,实验和估算器框架将记录 TensorBoard 能够显示的某些统计信息。若是咱们运行:
咱们就能够看到全部训练统计数据,如训练损失、评估准确性、每步时间和模型图。
评估精度在 TensorBoard 中的可视化
在 TensorFlow 中,有关 Estimator、Experiment 和 Dataset 框架的示例不多,这也是本文存在的缘由。但愿这篇文章能够向你们介绍这些架构工做的原理,它们应该采用哪些抽象方法,以及如何使用它们。若是你对它们很感兴趣,如下是其余相关文档。
关于 Estimator、Experiment 和 Dataset 的注释
  • 论文《TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks》:https://terrytangyuan.github.io/data/papers/tf-estimators-kdd-paper.pdf
  • Using the Dataset API for TensorFlow Input Pipelines:https://www.tensorflow.org/versions/r1.3/programmers_guide/datasets
  • tf.estimator.Estimator:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator
  • tf.contrib.learn.RunConfig:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig
  • tf.estimator.DNNClassifier:https://www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier
  • tf.estimator.DNNRegressor:https://www.tensorflow.org/api_docs/python/tf/estimator/DNNRegressor
  • Creating Estimators in tf.estimator:https://www.tensorflow.org/extend/estimators
  • tf.contrib.learn.Head:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Head
  • 本文用到的 Slim 框架:https://github.com/tensorflow/models/tree/master/slim
完整示例









推理训练模式
在训练模型后,咱们能够运行 estimateator.predict 来预测给定图像的类别。可以使用如下代码示例。



原文连接:https://medium.com/onfido-tech/higher-level-apis-in-tensorflow-67bfb602e6c0

选自Medium
做者:Peter Roelants
机器之心编译
参与:李泽南、黄小天

本文为机器之心编译,转载请联系本公众号得到受权。
相关文章
相关标签/搜索