"笨方法"学习CNN图像识别(二)—— tfrecord格式高效读取数据

 

原文地址:https://finthon.com/learn-cnn-two-tfrecord-read-data/
-- 全文阅读5分钟 --python

在本文中,你将学习到如下内容:


  • 将图片数据制做成tfrecord格式
  • 将tfrecord格式数据还原成图片

前言

tfrecord是TensorFlow官方推荐的标准格式,可以将图片数据和标签一块儿存储成二进制文件,在TensorFlow中实现快速地复制、移动、读取和存储操做。训练网络的时候,经过创建队列系统,能够预先将tfrecord格式的数据加载进队列,队列会自动实现数据随机或有序地进出栈,而且队列系统和模型训练是独立进行的,这就加速了咱们模型的读取和训练。swift

准备图片数据

按照图片预处理教程,咱们得到了两组resize成224*224大小的商标图片集,把标签分别命名成1和2两类,以下图:网络

 
两类图片数据集

 

 
label:1

 

 
label:2


咱们如今就将这两个类别的图片集制做成tfrecord格式。数据结构

 

制做tfrecord格式

导入必要的库:app

import os from PIL import Image import tensorflow as tf 

定义一些路径和参数:函数

# 图片路径,两组标签都在该目录下 cwd = r"./brand_picture/" # tfrecord文件保存路径 file_path = r"./" # 每一个tfrecord存放图片个数 bestnum = 1000 # 第几个图片 num = 0 # 第几个TFRecord文件 recordfilenum = 0 # 将labels放入到classes中 classes = [] for i in os.listdir(cwd): classes.append(i) # tfrecords格式文件名 ftrecordfilename = ("traindata_63.tfrecords-%.3d" % recordfilenum) writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename)) 

bestnum控制每一个tfrecord的大小,这里使用1000,首先定义tf.python_io.TFRecordWriter,方便后面写入存储数据。
制做tfrecord格式时,其实是将图片和标签一块儿存储在tf.train.Example中,它包含了一个字典,键是一个字符串,值的类型能够是BytesList,FloatList和Int64List。学习

for index, name in enumerate(classes): class_path = os.path.join(cwd, name) for img_name in os.listdir(class_path): num = num + 1 if num > bestnum: #超过1000,写入下一个tfrecord num = 1 recordfilenum += 1 ftrecordfilename = ("traindata_63.tfrecords-%.3d" % recordfilenum) writer = tf.python_io.TFRecordWriter(os.path.join(file_path, ftrecordfilename)) img_path = os.path.join(class_path, img_name) # 每个图片的地址 img = Image.open(img_path, 'r') img_raw = img.tobytes() # 将图片转化为二进制格式 example = tf.train.Example( features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), })) writer.write(example.SerializeToString()) # 序列化为字符串 writer.close() 

在这里咱们保存的label是classes中的编号索引,即0和1,你也能够改为文件名做为label,可是必定是int类型。图片读取之后转化成了二进制格式。最后经过writer写入数据到tfrecord中。
最终咱们在当前目录下生成一个tfrecord文件:ui

 
tfrecord文件

读取tfrecord文件

读取tfrecord文件是存储的逆操做,咱们定义一个读取tfrecord的函数,方便后面调用。spa

import tensorflow as tf def read_and_decode_tfrecord(filename): filename_deque = tf.train.string_input_producer(filename) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_deque) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = tf.cast(features['label'], tf.int32) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) img = tf.cast(img, tf.float32) / 255.0 return img, label train_list = ['traindata_63.tfrecords-000'] img, label = read_and_decode_tfrecord(train_list) 

这段代码主要是经过tf.TFRecordReader读取里面的数据,而且还原数据类型,最后咱们对图片矩阵进行归一化。到这里咱们就完成了tfrecord输出,能够对接后面的训练网络了。
若是咱们想直接还原成原来的图片,就须要先注释掉读取tfrecord函数中的归一化一行,并添加部分代码,完整代码以下:线程

import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt def read_and_decode_tfrecord(filename): filename_deque = tf.train.string_input_producer(filename) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_deque) features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string)}) label = tf.cast(features['label'], tf.int32) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [224, 224, 3]) # img = tf.cast(img, tf.float32) / 255.0 #将矩阵归一化0-1之间 return img, label train_list = ['traindata_63.tfrecords-000'] img, label = read_and_decode_tfrecord(train_list) img_batch, label_batch = tf.train.batch([img, label], num_threads=2, batch_size=2, capacity=1000) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 建立一个协调器,管理线程 coord = tf.train.Coordinator() # 启动QueueRunner,此时文件名队列已经进队 threads = tf.train.start_queue_runners(sess=sess, coord=coord) while True: b_image, b_label = sess.run([img_batch, label_batch]) b_image = Image.fromarray(b_image[0]) plt.imshow(b_image) plt.axis('off') plt.show() coord.request_stop() # 其余全部线程关闭以后,这一函数才能返回 coord.join(threads) 

在后面创建了一个队列tf.train.batch,经过Session调用顺序队列系统,输出每张图片。Session部分在训练网络的时候还会讲到。咱们学习tfrecord过程,能加深对数据结构和类型的理解。到这里咱们对tfrecord格式的输入输出有了必定了解,咱们训练网络的准备工做已完成,接下来就是咱们CNN模型的搭建工做了。

可能感兴趣
相关文章
相关标签/搜索