一直以来都是用 tensorflow 框架实现深度学习算法和实验,在网络训练时有一个重要的问题就是训练数据的读取。tensorflow 支持流水线并行读取数据,这种方式将数据的读取和网络训练并行,数据读取效率和将全部数据载入内存后进行存取至关,却又不会增长内存开销,是很值得推荐的一种方式。这篇笔记就是总结一下本身在实际应用中的并行数据读取,留个备份,随时学习。git
主要参考了 Google HDRnet 代码:https://github.com/mgharbi/hdrnet,CycleGAN 代码:https://github.com/vanhuyz/CycleGAN-TensorFlowgithub
HDRnet工程里的 data_pipeline.py 文件提供了很是清晰的流水线读取数据示例,在官方代码的基础上,能够很轻松地针对本身的应用实现一套数据读取接口,假设咱们的训练数据存储在目录 training_data/input 和 training_data/output,input 存储网络训练输入,output 存储网络目标输出,一对训练样本的输入和目标输出名称相同,均为二进制文件 *.dat,如下面代码为示例展现如何实现流水线并行数据读取:算法
def data_generator(params, data_path): filelist = os.listdir(data_path) # 获取训练目录下的文件名列表 if params.shuffle: random.shuffle(filelist) # 随机打乱训练数据 input_files = [os.path.join(data_path, 'input', f) for f in filelist if f.endswith('.dat')] # 生成输入数据文件名列表 output_files = [os.path.join(data_path, 'output', f) for f in filelist if f.endswith('.dat')] # 生成目标输出文件名列表
# 基于给定的文件名列表,建立先入先出的文件名队列,输入能够是多个文件名列表,输出对应的对个文件名队列 input_queue, output_queue = tf.train.slice_input_producer( [input_files, output_files], shuffle=params.shuffle, seed=params.seed, num_epochs=params.num_epochs) input_reader = tf.read_file(input_queue) # 建立 reader,读取输入数据 output_reader = tf.read_file(output_queue) # 建立 reader,读取目标输出
# 根据文件类型的不一样解析数据,若是文件是图像,可使用 tf.image.decode_jpeg 等函数解析 if os.path.splitext(input_files[0])[-1] == '.jpg': input = tf.image.decode_jpeg(input_reader, channels=3) else: input = tf.decode_raw(input_reader, data_type=tf.uint16) # 若是是二进制信息存储,则可使用 tf.decode_raw 函数解析 input = tf.reshape(input, [params.height, params.width, params.channel]) # 将数据 reshape 为正确的形状,此处以图像 (height, width, channel) 为例 if os.path.splitext(output_files[0])[-1] == '.jpg': output = tf.image.decode_jpeg(output_reader, channels=3) else: output = tf.decode_raw(output_reader, data_type=tf.uint16) input = tf.reshape(input, [params.height, params.width, params.channel])
# 上面读取了单个输入和对应的目标输出,网络训练时如需数据增广,能够在读取单个训练对以后,使用函数对数据进行处理,扩大训练集 input, output = augment_data(input, output) samples = {} # 将增广后的一对训练数据组织为字典的形式,便于后面组织成 batch samples['input'] = input samples['output'] = output if param.shuffle: # 建立批样例训练数据 samples = tf.train.shuffle_batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity, min_after_dequeue=params.min_after_dequeue) else: samples = tf.train.batch( sample, batch_size=params.batch_size, num_threads=params.nthreads, capacity=params.capacity) return samples # 返回一个 batch 的训练数据
代码中具体函数的接口能够经过 tensorflow 的文档查清。以上,只是声明了多线程的文件读取操做,并不会真正的读取数据,为了在会话执行时顺利地获取输入数据,须要使用 tf.train.start_queue_runners 来启动执行入队列操做的全部线程,具体过程包括:文件名入队到文件名队列,样例入队到样例队列。示例代码以下:安全
params.shuffle = true params.seed = 1234 params.height = 224 params.width = 224 params.channel = 3 training_path = 'dir/to/training/data' training_samples = data_generator(params, training_path) batch_inputs = training_samples['input'] batch_outputs = training_sample['output']
# 网络计算图建立
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)
loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess = sess)
sess.run(train_op)
...
上面的代码中,输入输出各只有一张图像,展现了如何实现流水线读取,以及如何使用读取出的数据。当输入或者输出包含多个文件时,例如,输入是图像和其语义分割图,能够在 data_generator 函数中,增长对语义分割图的读取,相对应的,多了 seg_files、seg_queue、seg_reader、seg_map 以及最后的 samples['seg_map'] = seg_map;网络
一样,当输入数据是其它格式时,只须要根据对应的格式修改数据读取的代码接,例如 CycleGAN 中,训练数据存储为 tfrecord 格式,须要修改的其实就是对文件的读取部分。多线程
咱们都知道,tensorflow 在建立网络计算图时,一般须要为网络输入和目标输出先声明 placeholder,可是上面的第二段示例代码则是直接使用数据读取的输出构建网络计算图,是否是说采用这种方式就不能采用常见方法那样,先定义 placeholder,再在网络训练中使用 feed_dict 填充数据呢?答案是能够的,方法也和一般的作法没有太大区别,示例以下:框架
x = tf.placeholder(...)
y = tf.palceholder(...) conv_1 = Conv2D(y, ...) ...
loss = tf.reduce_mean(tf.squared_difference(net_y, y))
train_op = tf.minimize(loss, ...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess)
samples = data_generator(params, training_path)
sess.run(train_op, feed_dict={x: samples['input'], y: samples['output']})
和第一种方法的区别是 data_generator 是在会话 sess 中调用,而不是在构建网络计算图时调用;dom
须要注意的是,上面的方式容错性比较差,主要是由于采用多线程方式读取数据,队列操做后台线程的生命周期无管理机制,线程出现异常会致使程序崩溃,比较常见的异常是文件名队列或者样例队列越界抛出的 tf.errors.OutOfRangeError。为了处理这种异常,HDRnet、CycleGAN 工程代码中都使用 tf.train.Coordinator 建立了管理多线程声明周期的协调器,其工做原理是经过监控 tensorflow 全部后台线程,当有线程出现异常时,协调器的 should_stop 成员方法返回 True,循环结束,而后会话执行协调器的 request_stop 方法,请求全部线程安全退出。一套完整的示例代码以下:函数
params.shuffle = true
params.seed = 1234
params.height = 224
params.width = 224
params.channel = 3
training_path = 'dir/to/training/data'
training_samples = data_generator(params, training_path)
batch_inputs = training_samples['input']
batch_outputs = training_sample['output']
# 网络计算图建立
conv_1 = Conv2D(batch_inputs, ...)
...
conv_n = Conv2D(conv_n-1, ...)
output = tf.sigmoid(conv_n)
loss = tf.reduce_mean(tf.squared_difference(output, batch_outputs))
train_op = tf.minimize(loss,...)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
while not coord.should_stop():
sess.run(train_op)
except KeyboardInterrupt: # 响应 Ctrl+C 中止训练
coord.request_stop()
except Exception as e: # 后台线程出现异常
coord.request_stop(e)
finally: # 这一步总会执行
save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step) # 保存 checkpoint
coord.request_stop()
coord.join(threads)
以上,介绍 tensorflow 中如何使用多线程并行读取数据,如何在训练中使用读取的数据,以及如何对多线程进行监视,提高网络训练的容错性。分享给你们,也给本身学习。学习