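TensorFlow:使用 tf.data.TFRecordDataset() 读取 TFRecord 文件

许多输入管道都从 TFRecord 格式的文件中提取 tf.train.Example 协议缓冲区消息(这类文件例如可以用 tf.python_io.TFRecordWriter 写入)。每一个 tf.train.Example 记录都包含一个或多个“特征”,输入管道一般会将这些特征转换为张量。下面的 Python 示例(TensorFlow 1.x API)从 TFRecord 文件中读取由 query / positive / negative 三张图像组成的样本,并用 matplotlib 显示出来: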
def input_layer():
    """Build a re-initializable input pipeline over TFRecord files.

    Returns:
        A 3-tuple ``(filenames, iterator, next_element)`` where
        ``filenames`` is a 1-D ``tf.string`` placeholder that must be fed
        with the TFRecord paths when running ``iterator.initializer``,
        ``iterator`` is the initializable iterator over the parsed records,
        and ``next_element`` is ``iterator.get_next()`` -- the structure
        returned by ``_parse_example`` for one record.
    """
    # The file list is a placeholder rather than a constant, so the same
    # graph can be initialized against whichever files are fed at run time.
    file_list = tf.placeholder(tf.string, shape=[None])
    # Parse every serialized record into tensors.
    records = tf.data.TFRecordDataset(file_list).map(_parse_example)
    # No .repeat(): the iterator yields each record once and then stops.
    it = records.make_initializable_iterator()
    next_element = it.get_next()
    return file_list, it, next_element
复制代码
def _parse_example(example):
    """Parse one serialized ``tf.train.Example`` into an image triplet.

    The record must contain three JPEG-encoded byte strings stored under the
    keys ``'img_query'``, ``'img_positive'`` and ``'img_negative'``.

    Args:
        example: scalar ``tf.string`` tensor holding one serialized record,
            as yielded by ``tf.data.TFRecordDataset``.

    Returns:
        Tuple ``(img_query, img_positive, img_negative)`` of ``float32``
        tensors of shape ``[height, width, 3]`` with values in ``[0, 1]``.
    """
    image_keys = ('img_query', 'img_positive', 'img_negative')
    keys_to_feature = {
        key: tf.FixedLenFeature((), tf.string) for key in image_keys
    }
    feat_tensor_maps = tf.parse_single_example(example, keys_to_feature)

    def _process_img(img_bytes):
        # channels=3 forces RGB output. With the default (channels=0) the
        # channel count follows each file, so a grayscale JPEG decodes to
        # [h, w, 1]: the static channel dimension is unknown, the three
        # images of a triplet may disagree, and pyplot.imshow rejects that
        # shape.
        img = tf.image.decode_jpeg(img_bytes, channels=3)
        # Rescale uint8 pixels from [0, 255] to float32 in [0, 1].
        return tf.cast(img, tf.float32) / 255.0

    # Order of image_keys fixes the order of the returned triplet.
    return tuple(_process_img(feat_tensor_maps[key]) for key in image_keys)
复制代码
tfrecord_path = '/media/ubuntu/FED8DCB6D8DC6E81/stuff/deep_ranking_tfrecord/train.record'
# Build the input graph once; the concrete file list is supplied via a feed.
filenames_tensor, iterator, ele_tensor = input_layer()
feed = {filenames_tensor: [tfrecord_path]}
with tf.Session() as sess:
    # The initializable iterator must be initialized, with the file list
    # fed, before any element can be fetched from it.
    sess.run(iterator.initializer, feed_dict=feed)
    # Fetch the first record. The feed is kept as in the original, although
    # get_next presumably does not read the placeholder after initialization.
    img_eval = sess.run(ele_tensor, feed_dict=feed)
    img_query, img_positive, img_negative = img_eval
    # Show query, positive and negative images one after another.
    for image in (img_query, img_positive, img_negative):
        plt.imshow(image)
        plt.show()
复制代码