TFrecord是一个Google提供的用于深度学习的数据格式,我的以为很方便规范,值得学习。本文主要讲的是怎么存储array,别的数据存储较为简单,触类旁通就行。python
在TFrecord中的数据都须要进行一个转化的过程,这个转化分红三种数据结构
通常来说咱们的图片读进来之后是两种形式,多线程
可是存储在TFrecord里面的不能是array的形式,因此咱们须要利用tostring()将上面的矩阵转化成字符串再经过tf.train.BytesList转化成能够存储的形式。学习
下面给个实例代码,你们看看就懂了ui
adjust_pic.py : 做用就是转化Image大小编码
# -*- coding: utf-8 -*- import tensorflow as tf def resize(img_data, width, high, method=0): return tf.image.resize_images(img_data,[width, high], method)
pic2tfrecords.py :将图片存成TFrecord.net
# -*- coding: utf-8 -*- # 将图片保存成 TFRecord import os.path import matplotlib.image as mpimg import tensorflow as tf import adjust_pic as ap from PIL import Image SAVE_PATH = 'data/dataset.tfrecords' def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def load_data(datafile, width, high, method=0, save=False): train_list = open(datafile,'r') # 准备一个 writer 用来写 TFRecord 文件 writer = tf.python_io.TFRecordWriter(SAVE_PATH) with tf.Session() as sess: for line in train_list: # 得到图片的路径和类型 tmp = line.strip().split(' ') img_path = tmp[0] label = int(tmp[1]) # 读取图片 image = tf.gfile.FastGFile(img_path, 'r').read() # 解码图片(若是是 png 格式就使用 decode_png) image = tf.image.decode_jpeg(image) # 转换数据类型 # 由于为了将图片数据可以保存到 TFRecord 结构体中,因此须要将其图片矩阵转换成 string,因此为了在使用时可以转换回来,这里肯定下数据格式为 tf.float32 image = tf.image.convert_image_dtype(image, dtype=tf.float32) # 既然都将图片保存成 TFRecord 了,那就先把图片转换成但愿的大小吧 image = ap.resize(image, width, high) # 执行 op: image image = sess.run(image) # 将其图片矩阵转换成 string image_raw = image.tostring() # 将数据整理成 TFRecord 须要的数据结构 example = tf.train.Example(features=tf.train.Features(feature={ 'image_raw': _bytes_feature(image_raw), 'label': _int64_feature(label), })) # 写 TFRecord writer.write(example.SerializeToString()) writer.close() load_data('train_list.txt_bak', 224, 224)
tfrecords2data.py :读取Tfrecord里的内容线程
# -*- coding: utf-8 -*- # 从 TFRecord 中读取并保存图片 import tensorflow as tf import numpy as np SAVE_PATH = 'data/dataset.tfrecords' def load_data(width, high): reader = tf.TFRecordReader() filename_queue = tf.train.string_input_producer([SAVE_PATH]) # 从 TFRecord 读取内容并保存到 serialized_example 中 _, serialized_example = reader.read(filename_queue) # 读取 serialized_example 的格式 features = tf.parse_single_example( serialized_example, features={ 'image_raw': tf.FixedLenFeature([], tf.string), 'label': tf.FixedLenFeature([], tf.int64), }) # 解析从 serialized_example 读取到的内容 images = tf.decode_raw(features['image_raw'], tf.uint8) labels = tf.cast(features['label'], tf.int64) with tf.Session() as sess: # 启动多线程 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 由于我这里只有 2 张图片,因此下面循环 2 次 for i in range(2): # 获取一张图片和其对应的类型 label, image = sess.run([labels, images]) # 这里特别说明下: # 由于要想把图片保存成 TFRecord,那就必须先将图片矩阵转换成 string,即: # pic2tfrecords.py 中 image_raw = image.tostring() 这行 # 因此这里须要执行下面这行将 string 转换回来,不然会没法 reshape 成图片矩阵,请看下面的小例子: # a = np.array([[1, 2], [3, 4]], dtype=np.int64) # 2*2 的矩阵 # b = a.tostring() # # 下面这行的输出是 32,即: 2*2 以后还要再乘 8 # # 若是 tostring 以后的长度是 2*2=4 的话,那能够将 b 直接 reshape([2, 2]),但如今的长度是 2*2*8 = 32,因此没法直接 reshape # # 同理若是你的图片是 500*500*3 的话,那 tostring() 以后的长度是 500*500*3 后再乘上一个数 # print len(b) # # 但在网上有不少提供的代码里都没有下面这一行,大家那真的能 reshape ? image = np.fromstring(image, dtype=np.float32) # reshape 成图片矩阵 image = tf.reshape(image, [224, 224, 3]) # 由于要保存图片,因此将其转换成 uint8 image = tf.image.convert_image_dtype(image, dtype=tf.uint8) # 按照 jpeg 格式编码 image = tf.image.encode_jpeg(image) # 保存图片 with tf.gfile.GFile('pic_%d.jpg' % label, 'wb') as f: f.write(sess.run(image)) load_data(224, 224)
以上代码摘自TFRecord 的使用,以为挺好的,没改原样照搬,我本身作实验时改了不少,由于我是在im2txt的基础上写的。code