关于Tensorflow读取数据,官网给出了三种方法:python
对于数据量较小而言,可能通常选择直接将数据加载进内存,而后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,你们本身尝试一下吧,我就不赘述了)。可是,若是数据量较大,这样的方法就不适用了,由于太耗内存,因此这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,好比csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即便用tensorflow内定标准格式——TFRecords网络
TFRecords
TFRecords实际上是一种二进制文件,虽然它不如其余格式好理解,可是它能更好的利用内存,更方便复制和移动,而且不须要单独的标签文件。函数
TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。咱们能够写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 而且经过tf.python_io.TFRecordWriter 写入到TFRecords文件。ui
从TFRecords文件中读取数据, 可使用tf.TFRecordReader的tf.parse_single_example解析器。这个操做能够将Example协议内存块(protocol buffer)解析为张量。spa
存入TFRecords文件须要数据先存入名为example的protocol buffer,而后将其serialize成为string才能写入。example中包含features,用于描述数据类型:bytes,float,int64。.net
咱们使用tf.train.Example
来定义咱们要填入的数据格式,而后使用tf.python_io.TFRecordWriter
来写入。线程
writer = tf.python_io.TFRecordWriter(out_name) #对每条数据分别得到文档,问题,答案三个值,并将相应单词转化为索引 #调用Example和Features函数将数据格式化保存起来。注意Features传入的参数应该是一个字典,方便后续读数据时的操做 example = tf.train.Example( features = tf.train.Features( feature = { 'document': tf.train.Feature( int64_list=tf.train.Int64List(value=document)), 'query': tf.train.Feature( int64_list=tf.train.Int64List(value=query)), 'answer': tf.train.Feature( int64_list=tf.train.Int64List(value=answer)) })) #写数据 serialized = example.SerializeToString() writer.write(serialized)
也能够用extend的方式:code
example = tf.train.Example() example.features.feature["context"].int64_list.value.extend(context_transformed)
example.features.feature["utterance"].int64_list.value.extend(utterance_transformed) example.features.feature["context_len"].int64_list.value.extend([context_len]) example.features.feature["utterance_len"].int64_list.value.extend([utterance_len]) writer = tf.python_io.TFRecordWriter(output_filename) writer.write(example.SerializeToString()) writer.close()
读取tfrecords文件orm
首先用tf.train.string_input_producer
读取tfrecords文件的list创建FIFO序列,能够申明num_epoches和shuffle参数表示须要读取数据的次数以及时候将tfrecords文件读入顺序打乱,而后定义TFRecordReader读取上面的序列返回下一个record,用tf.parse_single_example
对读取到TFRecords文件进行解码,根据保存的serialize example和feature字典返回feature所对应的值。此时得到的值都是string,须要进一步解码为所需的数据类型。把图像数据的string reshape成原始图像后能够进行preprocessing操做。此外,还能够经过tf.train.batch
或者tf.train.shuffle_batch
将图像生成batch序列。blog
因为tf.train
函数会在graph中增长tf.train.QueueRunner
类,而这些类有一系列的enqueue选项使一个队列在一个线程里运行。为了填充队列就须要用tf.train.start_queue_runners
来为全部graph中的queue runner启动线程,而为了管理这些线程就须要一个tf.train.Coordinator
来在合适的时候终止这些线程。
由于在读取数据以后咱们可能还会进行一些额外的操做,使咱们的数据格式知足模型输入,因此这里会引入一些额外的函数来实现咱们的目的。这里介绍几个我的感受较重要经常使用的函数。不过仍是推荐到官网API去查,或者有某种需求的时候到Stack Overflow上面搜一搜,通常都能找到知足本身需求的函数。
1,string_input_producer(
string_tensor,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None
)
其输出是一个输入管道的队列,这里须要注意的参数是num_epochs和shuffle。对于每一个epoch其会将全部的文件添加到文件队列当中,若是设置shuffle,则会对文件顺序进行打乱。其对文件进行均匀采样,而不会致使上下采样。
2,shuffle_batch(
tensors,
batch_size,
capacity,
min_after_dequeue,
num_threads=1,
seed=None,
enqueue_many=False,
shapes=None,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
产生随机打乱以后的batch数据
3,sparse_ops.serialize_sparse(sp_input, name=None): 返回一个字符串的3-vector(1-D的tensor),分别表示索引、值、shape
4,deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): 将多个稀疏的serialized_sparse合并成一个
基本的,一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List
就这样,咱们把相关的信息都存到了一个文件中,并且读取也很方便。
简单的读取小例子
for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"): example = tf.train.Example() example.ParseFromString(serialized_example) context = example.features.feature['context'].int64_list.value utterance = example.features.feature['utterance'].int64_list.value
一旦生成了TFRecords文件,为了高效地读取数据,TF中使用队列(queue
)读取数据。
def read_and_decode(filename): #根据文件名生成一个队列 filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw' : tf.FixedLenFeature([], tf.string), }) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 label = tf.cast(features['label'], tf.int32) return img, label
以后咱们能够在训练的时候这样使用
img, label = read_and_decode("train.tfrecords") #使用shuffle_batch能够随机打乱输入 img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=2000, min_after_dequeue=1000) init = tf.initialize_all_variables() with tf.Session() as sess: sess.run(init)
# 这是填充队列的指令,若是不执行程序会等在队列文件的读取处没法运行 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(3): val, l= sess.run([img_batch, label_batch]) #咱们也能够根据须要对val, l进行处理 #l = to_categorical(l, 12) print(val.shape, l)
注意:
第一,tensorflow里的graph可以记住状态(state),这使得TFRecordReader可以记住tfrecord的位置,而且始终能返回下一个。而这就要求咱们在使用以前,必须初始化整个graph,这里咱们使用了函数tf.initialize_all_variables()来进行初始化。
第二,tensorflow中的队列和普通的队列差很少,不过它里面的operation和tensor都是符号型的(symbolic),在调用sess.run()时才执行。
第三, TFRecordReader会一直弹出队列中文件的名字,直到队列为空。
record reader
解析tfrecord文件batcher
)QueueRunner
参考:
https://blog.csdn.net/u012759136/article/details/52232266
https://blog.csdn.net/liuchonge/article/details/73649251