Tensorflow读取文件到队列文件

TensorFlow读取二进制文件数据到队列
2016-11-03 09:30:00       0 个评论    来源: diligent_321的博客  
收藏    我要投稿

TensorFlow是一种符号编程框架(与theano相似),先构建数据流图再输入数据进行模型训练。Tensorflow支持不少种样例输入的方式。最容易的是使用placeholder,但这须要手动传递numpy.array类型的数据。第二种方法就是使用二进制文件和输入队列的组合形式。这种方式不只节省了代码量,避免了进行data augmentation和读文件操做,能够处理不一样类型的数据, 并且也再也不须要人为地划分开“预处理”和“模型计算”。在使用TensorFlow进行异步计算时,队列是一种强大的机制。php

队列使用概述html

正如TensorFlow中的其余组件同样,队列就是TensorFlow图中的节点。这是一种有状态的节点,就像变量同样:其余节点能够修改它的内容。具体来讲,其余节点能够把新元素插入到队列后端(rear),也能够把队列前端(front)的元素删除。队列,如FIFOQueue和RandomShuffleQueue(A queue implementation that dequeues elements in a random order.)等对象,在TensorFlow的tensor异步计算时都很是重要。例如,一个典型的输入结构是使用一个RandomShuffleQueue来做为模型训练的输入,多个线程准备训练样本,而且把这些样本压入队列,一个训练线程执行一个训练操做,此操做会从队列中移除最小批次的样本(mini-batches),这种结构具备许多优势。前端

TensorFlow的Session对象是能够支持多线程的,所以多个线程能够很方便地使用同一个会话(Session)而且并行地执行操做。然而,在Python程序实现这样的并行运算却并不容易。全部线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候, 队列必须能被正确地关闭。所幸TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一块儿使用。Coordinator类能够用来同时中止多个工做线程而且向那个在等待全部工做线程终止的程序报告异常。QueueRunner类用来协调多个工做线程同时将多个tensor压入同一个队列中。java

(1)读二进制文件数据到队列中python

同不少其余的深度学习框架同样,TensorFlow有它本身的二进制格式。它使用了a mixture of its Records 格式和protobuf。Protobuf是一种序列化数据结构的方式,给出了关于数据的一些描述。TFRecords是tensorflow的默认数据格式,一个record就是一个包含了序列化tf.train.Example 协议缓存对象的二进制文件,可使用python建立这种格式,而后即可以使用tensorflow提供的函数来输入给机器学习模型。
?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import tensorflow as tf
 
def read_and_decode_single_example(filename_queue):
# 定义一个空的类对象,相似于c语言里面的结构体定义
     class Image(self):
     pass
     image = Image()
     image.height = 32
     image.width = 32
     image.depth = 3
     label_bytes = 1
     
     Bytes_to_read = label_bytes+image.heigth*image.width* 3
     # A Reader that outputs fixed-length records from a file
     reader = tf.FixedLengthRecordReader(record_bytes=Bytes_to_read)
     # Returns the next record (key, value) pair produced by a reader, key 和value都是字符串类型的tensor
     # Will dequeue a work unit from queue if necessary (e.g. when the
     # Reader needs to start reading from a new file since it has
     # finished with the previous file).
     image.key, value_str = reader.read(filename_queue)
     # Reinterpret the bytes of a string as a vector of numbers,每个数值占用一个字节,在[ 0 , 255 ]区间内,所以out_type要取uint8类型
     value = tf.decode_raw(bytes=value_str, out_type=tf.uint8)
     # Extracts a slice from a tensor, value中包含了label和feature,故要对向量类型tensor进行 'parse' 操做
     image.label = tf.slice(input_=value, begin=[ 0 ], size=[ 1 ])
     value = value.slice(input_=value, begin=[ 1 ], size=[- 1 ]).reshape((image.depth, image.height, image.width))
     transposed_value = tf.transpose(value, perm=[ 2 , 0 , 1 ])
     image.mat = transposed_value
     return image
接下来咱们即可以调用这个函数了,
?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
filenames =[os.path.join(data_dir, 'test_batch.bin' )]
# Output strings (e.g. filenames) to a queue for an input pipeline
filename_queue = tf.train.string_input_producer(string_tensor=filenames)
# returns symbolic label and image
img_obj = read_and_decode_single_example( "filename_queue" )
Label = img_obj.label
Image = img_obj.mat
sess = tf.Session()
# 初始化tensorflow图中的全部状态,如待读取的下一个记录tfrecord的位置,variables等
init = tf.initialize_all_variables()
sess.run(init)
tf.train.start_queue_runners(sess=sess)
# grab examples back.
# first example from file
label_val_1, image_val_1 = sess.run([label, image])
# second example from file
label_val_2, image_val_2 = sess.run([label, image])
值得一提的是,TFRecordReader老是做用于文件名队列。它将会从队列中弹出文件名并使用该文件名,直到tfrecord为空时中止,此时它将从文件名队列中弹出下一个filename。然而,文件名队列又是怎么得来的呢?起初这个队列是空的,QueueRunners的概念即源于此。QueueRunners本质上就是一个线程thread,这个线程负责使用会话session并不断地调用enqueue操做。Tensorflow把这个模式封装在tf.train.QueueRunner对象里面。入队列操做99%的时间均可以被忽略掉,由于这个操做是由后台负责运行。(好比在上面的例子中,tf.train.string_input_producer建立了一个这样的线程,添加QueueRunner到数据流图中)。

可想而知,在你运行任何训练步骤以前,咱们要告知tensorflow去启动这些线程,不然这些队列会由于等待数据入队而被堵塞,致使数据流图将一直处于挂起状态。咱们能够调用tf.train.start_queue_runners(sess=sess)来启动全部的QueueRunners。这个调用并非符号化的操做,它会启动输入管道的线程,填充样本到队列中,以便出队操做能够从队列中拿到样本。另外,必需要先运行初始化操做再建立这些线程。若是这些队列未被初始化,tensorflow会抛出错误。web

(2)从二进制文件中读取mini-batchs编程

在训练机器学习模型时,使用单个样例更新参数属于“online learning”,然而在线下环境下,咱们一般采用基于mini-batchs 随机梯度降低法(SGD),可是在tensorflow中如何利用queuerunners返回训练块数据呢?请参见下面的程序:
?
1
2
3
4
5
image_batch, label_batch = tf.train.shuffle_batch(tensor_list=[image, label]],
                                                   batch_size=batch_size,
                                                   num_threads= 24 ,
                                                   min_after_dequeue=min_samples_in_queue,
                                                   capacity=min_samples_in_queue+ 3 *batch_size)

 

读取batch数据须要使用新的队列queues和QueueRunners(大体流程图以下)。Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的(image,labels)对送入队列中,这个入队操做是经过QueueRunners启动另外的线程来完成的。这个RandomShuffleQueue会顺序地压样例到队列中,直到队列中的样例个数达到了batch_size+min_after_dequeue个。它而后从队列中选择batch_size个随机的元素进行返回。事实上,shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果。有了这个batches变量,就能够开始训练机器学习模型了。

\

 

函数 tf.train.shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, shared_name=None, name=None)的使用说明:
做用:Creates batches by randomly shuffling tensors.(从队列中随机筛选多个样例返回给image_batch和label_batch);
参数说明:
tensor_list: The list of tensors to enqueue.(待入队的tensor list);
batch_size: The new batch size pulled from the queue;
capacity: An integer. The maximum number of elements in the queue(队列长度);
min_after_dequeue: Minimum number elements in the queue after a dequeue, used to ensure a level of mixing of elements.(随机取样的样本整体最小值,用于保证所取mini-batch的随机性);
num_threads: The number of threads enqueuing `tensor_list`.(session会话支持多线程,这里能够设置多线程加速样本的读取)
seed: Seed for the random shuffling within the queue.
enqueue_many: Whether each tensor in `tensor_list` is a single example.(为False时表示tensor_list是一个样例,压入时占用队列中的一个元素;为True时表示tensor_list中的每个元素都是一个样例,压入时占用队列中的一个元素位置,能够看做为一个batch);
shapes: (Optional) The shapes for each example. Defaults to the inferred shapes for `tensor_list`.
shared_name: (Optional) If set, this queue will be shared under the given name across multiple sessions.

name: (Optional) A name for the operations.后端

相关文章
相关标签/搜索