1. Two new abstractions in tf.data

A tf.data.Dataset represents a sequence of elements, in which each element contains one or more Tensor objects.
Creating a source (e.g. Dataset.from_tensor_slices()) constructs a dataset from one or more tf.Tensor objects.
Applying a transformation (e.g. Dataset.batch()) constructs a dataset from one or more tf.data.Dataset objects.
A tf.data.Iterator provides the main way to extract elements from a dataset.
The operation returned by Iterator.get_next() yields the next element of a Dataset when executed, and it typically acts as the interface between input-pipeline code and your model. The simplest iterator is the "one-shot iterator", which is associated with a particular Dataset and iterates through it once. For more sophisticated uses, the Iterator.initializer operation lets you reinitialize and parameterize an iterator with different datasets.
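To make these two abstractions concrete, here is a minimal sketch, assuming a small in-memory dataset and an explicit tf.Session (both are illustrative assumptions):
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])   # source: four scalar elements
dataset = dataset.map(lambda x: x * 2).batch(2)              # transformations
iterator = dataset.make_one_shot_iterator()                  # extraction
next_element = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_element))  # ==> [2 4]
  print(sess.run(next_element))  # ==> [6 8]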
2. Basic mechanics

2.1. Defining a source

To build a Dataset from some tensors in memory, you can use tf.data.Dataset.from_tensors() or tf.data.Dataset.from_tensor_slices(). Alternatively, if your input data is stored on disk in the recommended TFRecord format, you can construct a tf.data.TFRecordDataset.
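For instance, a minimal sketch of each of these sources (the TFRecord path is a placeholder):
dataset1 = tf.data.Dataset.from_tensors(tf.constant([1, 2, 3]))        # a single element: the whole tensor
dataset2 = tf.data.Dataset.from_tensor_slices(tf.constant([1, 2, 3]))  # three scalar elements
dataset3 = tf.data.TFRecordDataset(["/var/data/file1.tfrecord"])       # one element per serialized record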
2.2. Transforming a Dataset into a new Dataset

Once you have a Dataset object, you can transform it into a new Dataset by chaining method calls on the tf.data.Dataset object. For example, you can apply per-element transformations such as Dataset.map() (which applies a function to each element), and multi-element transformations such as Dataset.batch().
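A minimal sketch of chaining a per-element and a multi-element transformation:
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x + 1).batch(5)  # per-element map, then batch into groups of 5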
2.3. Consuming values from a Dataset

The most common way to consume values from a Dataset is to build an iterator object, which provides access to one element of the dataset at a time (for example, by calling Dataset.make_one_shot_iterator()). A tf.data.Iterator provides two operations: Iterator.initializer, which lets you (re)initialize the iterator's state, and Iterator.get_next(), which returns tf.Tensor objects corresponding to the symbolic next element.
3. Dataset structure

A dataset comprises elements that all have the same structure. An element contains one or more tf.Tensor objects, called components. Each component has a tf.DType, representing the type of elements in the tensor, and a tf.TensorShape, representing the (possibly partially specified) static shape of each element. You can inspect the inferred types and shapes of each component of a dataset element via the Dataset.output_types and Dataset.output_shapes properties:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
4. Dataset transformations

Dataset transformations support datasets of any structure. When using the Dataset.map(), Dataset.flat_map(), and Dataset.filter() transformations, which apply a function to each element, the element structure determines the arguments of the function:
dataset1 = dataset1.map(lambda x: ...)
dataset2 = dataset2.flat_map(lambda x, y: ...)
# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)
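In Python 3, one possible workaround (a sketch reusing `dataset3` from above) is to accept the nested component as a single argument and unpack it inside the function:
def _filter_fn(x, yz):
  y, z = yz       # `y` is the scalar component, `z` is the (100,)-vector component
  return y > 0.5  # keep elements whose scalar component exceeds 0.5

dataset3 = dataset3.filter(_filter_fn)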
5. Creating an iterator

One-shot:
A one-shot iterator is the simplest form of iterator. It supports iterating once through a dataset and requires no explicit initialization. One-shot iterators handle almost all of the cases that the existing queue-based input pipelines support, but they do not support parameterization.
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
  value = sess.run(next_element)
  assert i == value
Initializable:
An initializable iterator requires you to run an explicit iterator.initializer operation before using it. In exchange, it lets you parameterize the definition of the dataset using one or more tf.placeholder() tensors that can be fed when you initialize the iterator.
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
  value = sess.run(next_element)
  assert i == value

# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
  value = sess.run(next_element)
  assert i == value
Reinitializable:
A reinitializable iterator can be initialized from multiple different Dataset objects, as long as those objects have the same structure (i.e., the same types and compatible shapes for each component).
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)
Feedable:
A feedable iterator can be used together with tf.placeholder to select which Iterator to use in each call to tf.Session.run, via the familiar feed_dict mechanism. It offers the same functionality as a reinitializable iterator, but it does not require you to initialize the iterator from the start of a dataset when you switch between iterators. It is created with tf.data.Iterator.from_string_handle:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})
6. Consuming values from an iterator

The Iterator.get_next() method returns one or more tf.Tensor objects that correspond to the symbolic next element of the iterator. Each time these tensors are evaluated, they take the value of the next element in the underlying dataset. (Note that, like other stateful objects in TensorFlow, calling Iterator.get_next() does not immediately advance the iterator. Instead, you must use the returned tf.Tensor objects in a TensorFlow expression and pass the result of that expression to tf.Session.run() to get the next element and advance the iterator.)

If the iterator reaches the end of the dataset, executing the Iterator.get_next() operation raises a tf.errors.OutOfRangeError. After this point the iterator is in an unusable state, and you must reinitialize it if you want to keep using it.
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)
sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
  sess.run(result)
except tf.errors.OutOfRangeError:
  print("End of dataset")  # ==> "End of dataset"

sess.run(iterator.initializer)
while True:
  try:
    sess.run(result)
  except tf.errors.OutOfRangeError:
    break
If each element of the dataset has a nested structure, the return value of Iterator.get_next() will be one or more tf.Tensor objects in the same nested structure:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
Note that next1, next2, and next3 are tensors produced by the same op/node (created by Iterator.get_next()). Therefore, evaluating any one of these tensors advances the iterator for all components. A typical consumer of an iterator will include all of the components in a single expression.
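A sketch of such a consumer, reusing the tensors defined above (the combining expression is purely illustrative):
total = tf.reduce_sum(next1) + next2 + tf.reduce_sum(next3)
print(sess.run(total))  # evaluates all three components and advances the iterator exactly once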
7. Saving iterator state

The tf.contrib.data.make_saveable_from_iterator function creates a SaveableObject from an iterator, which can be used to save and restore the current state of the iterator (and, effectively, the whole input pipeline). A saveable object created this way can be added to the tf.train.Saver variables list or to the tf.GraphKeys.SAVEABLE_OBJECTS collection, so that it is saved and restored in the same manner as a tf.Variable. See the Saving and Restoring guide for details on how to save and restore variables.
# Create saveable object from iterator.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
# Save the iterator state by adding it to the saveable objects collection.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()
with tf.Session() as sess:
  if should_checkpoint:
    saver.save(sess, path_to_checkpoint)

# Restore the iterator state.
with tf.Session() as sess:
  saver.restore(sess, path_to_checkpoint)
8. Reading input data

8.1. Consuming NumPy arrays
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
Note that the snippet above embeds the features and labels arrays in the TensorFlow graph as tf.constant() operations. This works well for a small dataset, but it wastes memory, because the contents of the arrays are copied multiple times, and it can run into the 2 GB limit of the tf.GraphDef protocol buffer.

As an alternative, you can define the Dataset in terms of tf.placeholder() tensors and feed the NumPy arrays when you initialize an Iterator over the dataset.
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
8.2. Consuming TFRecord data

The tf.data API supports a variety of file formats, so you can process large datasets that do not fit in memory. For example, the TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class lets you stream the contents of one or more TFRecord files as part of an input pipeline.

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
The filenames argument to the TFRecordDataset initializer can be a string, a list of strings, or a tf.Tensor of strings. Therefore, if you have two sets of files, one for training and one for validation, you can use a tf.placeholder(tf.string) to represent the filenames and initialize the iterator from the appropriate filenames:
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.
# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
8.3. Consuming text data

tf.data.TextLineDataset provides an easy way to extract lines from one or more text files. Given one or more filenames, a TextLineDataset produces one string-valued element per line of those files:
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
By default, a TextLineDataset yields every line of each file, which may not be desirable (for example, if a file starts with a header line or contains comments). These lines can be removed using the Dataset.skip() and Dataset.filter() transformations. To apply these transformations to each file separately, we use Dataset.flat_map() to create a nested Dataset for each file:
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
lambda filename: (
tf.data.TextLineDataset(filename)
.skip(1)
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
8.4. Consuming CSV data

Given one or more filenames and a list of defaults, a CsvDataset produces a tuple of elements per CSV record, whose types correspond to the types of the defaults provided:
# Creates a dataset that reads all of the records from two CSV files, each with
# eight float columns
filenames = ["/var/data/file1.csv", "/var/data/file2.csv"]
record_defaults = [tf.float32] * 8 # Eight required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)
# Creates a dataset that reads all of the records from two CSV files, each with
# eight float columns which may have missing values
record_defaults = [[0.0]] * 8
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)
# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [[0.0]] * 2 # Only provide defaults for the selected columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[2,4])
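A sketch of consuming such a dataset (the CSV paths above are placeholders, and a tf.Session `sess` is assumed, as in earlier snippets):
iterator = dataset.make_one_shot_iterator()
next_row = iterator.get_next()  # a tuple of scalar tensors, one per selected column
print(sess.run(next_row))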
9. Preprocessing data with Dataset.map()

The Dataset.map(f) transformation produces a new dataset by applying a given function f to each element of the input dataset.
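As a trivial sketch:
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x * 2)  # yields 0, 2, 4, 6, 8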
Parsing tf.Example protocol buffer messages:
# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
Decoding image data and resizing it:
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label
# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# `labels[i]` is the label for the image in `filenames[i]`.
labels = tf.constant([0, 37, ...])
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
Applying arbitrary Python logic with tf.py_func():
For performance reasons, we recommend using TensorFlow operations to preprocess your data whenever possible. However, it is sometimes useful to call external Python libraries when parsing your input data. To do so, invoke the tf.py_func() operation in a Dataset.map() transformation:
import cv2
# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tuple(tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)
Simple batching:
The simplest form of batching stacks n consecutive elements of a dataset into a single element. The Dataset.batch() transformation does exactly this, with the same constraints as the tf.stack() operator applied to each component of the elements: i.e., for each component i, all elements must have a tensor of exactly the same shape.
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
Batching tensors with padding: the recipe above requires all elements to share the same shape. For elements of varying size, the Dataset.padded_batch() transformation batches elements while padding them to a common shape:
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
# [5, 5, 5, 5, 5, 0, 0],
# [6, 6, 6, 6, 6, 6, 0],
# [7, 7, 7, 7, 7, 7, 7]]
The Dataset.padded_batch() transformation lets you set different padding for each dimension of each component, and the padding may be of variable length (signified by None in the example above) or constant length. You can also override the padding value, which defaults to 0.
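For example, a sketch that pads with -1 instead of the default 0, reusing the same variable-length dataset as above:
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None],
                               padding_values=tf.constant(-1, dtype=tf.int64))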
10. Training workflows

The simplest way to iterate over a dataset for multiple epochs is to use the Dataset.repeat() transformation:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)
Applying the Dataset.repeat() transformation with no arguments repeats the input indefinitely. The Dataset.repeat() transformation concatenates its repetitions without signaling the end of one epoch and the beginning of the next.

If you want to receive a signal at the end of each epoch, you can write a training loop that catches the tf.errors.OutOfRangeError raised at the end of the dataset. At that point you might collect some statistics for the epoch (e.g. the validation error):
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Compute for 100 epochs.
for _ in range(100):
  sess.run(iterator.initializer)
  while True:
    try:
      sess.run(next_element)
    except tf.errors.OutOfRangeError:
      break

  # [Perform end-of-epoch calculations here.]
The Dataset.shuffle() transformation randomly shuffles the input dataset using an algorithm similar to tf.RandomShuffleQueue: it maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
11. Using high-level APIs

The tf.train.MonitoredTrainingSession API simplifies many aspects of running TensorFlow in a distributed setting. MonitoredTrainingSession uses tf.errors.OutOfRangeError to signal that training has completed, so to use it with the tf.data API we recommend Dataset.make_one_shot_iterator(). For example:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)
with tf.train.MonitoredTrainingSession(...) as sess:
  while not sess.should_stop():
    sess.run(training_op)
To use a Dataset in the input_fn of a tf.estimator.Estimator, simply return the Dataset; the framework will take care of creating and initializing the iterator for you. For example:
def dataset_input_fn():
  filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
  dataset = tf.data.TFRecordDataset(filenames)

  # Use `tf.parse_single_example()` to extract data from a `tf.Example`
  # protocol buffer, and perform any additional per-record preprocessing.
  def parser(record):
    keys_to_features = {
        "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
        "date_time": tf.FixedLenFeature((), tf.int64, default_value=0),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Perform additional preprocessing on the parsed data.
    image = tf.image.decode_jpeg(parsed["image_data"])
    image = tf.reshape(image, [299, 299, 1])
    label = tf.cast(parsed["label"], tf.int32)

    return {"image_data": image, "date_time": parsed["date_time"]}, label

  # Use `Dataset.map()` to build a pair of a feature dictionary and a label
  # tensor for each example.
  dataset = dataset.map(parser)
  dataset = dataset.shuffle(buffer_size=10000)
  dataset = dataset.batch(32)
  dataset = dataset.repeat(num_epochs)

  # Each element of `dataset` is a tuple containing a dictionary of features
  # (in which each value is a batch of values for that feature), and a batch of
  # labels.
  return dataset
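A sketch of wiring this input function into an Estimator (the `model_fn` and `model_dir` here are hypothetical placeholders):
estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/model")  # `model_fn` assumed to be defined elsewhere
estimator.train(input_fn=dataset_input_fn)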