直接采用矩阵方式创建数据集见:http://www.javashuo.com/article/p-pjxbpkmd-kr.htmlhtml
制做本身的数据集(使用tfrecords)python
为何采用这个格式?多线程
TFRecords文件格式在图像识别中有很好的使用,其能够将二进制数据和标签数据(训练的类别标签)数据存储在同一个文件中,它能够在模型进行训练以前经过预处理步骤将图像转换为TFRecords格式,此格式最大的优势实践每幅输入图像和与之关联的标签放在同一个文件中.TFRecords文件是一种二进制文件,其不对数据进行压缩,因此能够被快速加载到内存中.格式不支持随机访问,所以它适合于大量的数据流,但不适用于快速分片或其余非连续存取。函数
前戏:ui
tf.train.Feature
tf.train.Feature有三个属性为tf.train.bytes_list tf.train.float_list tf.train.int64_list,显然咱们只须要根据上一步获得的值来设置tf.train.Feature的属性就能够了,以下所示:spa
1 tf.train.Feature(int64_list=data_id) 2 tf.train.Feature(bytes_list=data)
tf.train.Features
从名字来看,咱们应该能猜出tf.train.Features是tf.train.Feature的复数,事实上tf.train.Features有属性为feature,这个属性的通常设置方法是传入一个字典,字典的key是字符串(feature名),而值是tf.train.Feature对象。所以,咱们能够这样获得tf.train.Features对象:.net
1 feature_dict = { 2 "data_id": tf.train.Feature(int64_list=data_id), 3 "data": tf.train.Feature(bytes_list=data) 4 } 5 features = tf.train.Features(feature=feature_dict)
tf.train.Example
终于到咱们的主角了。tf.train.Example有一个属性为features,咱们只须要将上一步获得的结果再次当作参数传进来便可。
另外,tf.train.Example还有一个方法SerializeToString()须要说一下,这个方法的做用是把tf.train.Example对象序列化为字符串,由于咱们写入文件的时候不能直接处理对象,须要将其转化为字符串才能处理。
固然,既然有对象序列化为字符串的方法,那么确定有从字符串反序列化到对象的方法,该方法是FromString(),须要传递一个tf.train.Example对象序列化后的字符串进去作为参数才能获得反序列化的对象。
在咱们这里,只须要构建tf.train.Example对象并序列化就能够了,这一步的代码为:线程
1 example = tf.train.Example(features=features) 2 example_str = example.SerializeToString()
实例(高潮部分):code
首先看一下咱们的文件夹路径:htm
create_tfrecords.py中写咱们的函数
生成数据文件阶段代码以下:
1 def creat_tf(imgpath): 2 cwd = os.getcwd() #获取当前路径 3 classes = os.listdir(cwd + imgpath) #获取到[1, 2]文件夹 4 # 此处定义tfrecords文件存放 5 writer = tf.python_io.TFRecordWriter("train.tfrecords") 6 for index, name in enumerate(classes): #循环获取俩文件夹(俩类别) 7 class_path = cwd + imgpath + name + "/" 8 if os.path.isdir(class_path): 9 for img_name in os.listdir(class_path): 10 img_path = class_path + img_name 11 img = Image.open(img_path) 12 img = img.resize((224, 224)) 13 img_raw = img.tobytes() 14 example = tf.train.Example(features=tf.train.Features(feature={ 15 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])), 16 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) 17 })) 18 writer.write(example.SerializeToString()) 19 print(img_name) 20 writer.close()
这段代码主要生成 train.tfrecords 文件。
读取数据阶段代码以下:
1 def read_and_decode(filename): 2 # 根据文件名生成一个队列 3 filename_queue = tf.train.string_input_producer([filename]) 4 5 reader = tf.TFRecordReader() 6 _, serialized_example = reader.read(filename_queue) # 返回文件名和文件 7 features = tf.parse_single_example(serialized_example, 8 features={ 9 'label': tf.FixedLenFeature([], tf.int64), 10 'img_raw': tf.FixedLenFeature([], tf.string), 11 }) 12 13 img = tf.decode_raw(features['img_raw'], tf.uint8) 14 img = tf.reshape(img, [224, 224, 3]) 15 # 转换为float32类型,并作归一化处理 16 img = tf.cast(img, tf.float32) # * (1. / 255) 17 label = tf.cast(features['label'], tf.int64) 18 return img, label
训练阶段咱们获取数据的代码:
1 images, labels = read_and_decode('./train.tfrecords') 2 img_batch, label_batch = tf.train.shuffle_batch([images, labels], 3 batch_size=5, 4 capacity=392, 5 min_after_dequeue=200) 6 init = tf.global_variables_initializer() 7 with tf.Session() as sess: 8 sess.run(init) 9 coord = tf.train.Coordinator() #线程协调器 10 threads = tf.train.start_queue_runners(sess=sess,coord=coord) 11 # 训练部分代码-------------------------------- 12 IMG, LAB = sess.run([img_batch, label_batch]) 13 print(IMG.shape) 14 15 #---------------------------------------------- 16 coord.request_stop() # 协调器coord发出全部线程终止信号 17 coord.join(threads) #把开启的线程加入主线程,等待threads结束
总结(流程):
record reader
解析tfrecord文件batcher
)QueueRunner
备注:关于tf.train.Coordinator 详见:
https://blog.csdn.net/dcrmg/article/details/79780331
TensorFlow的Session对象是支持多线程的,能够在同一个会话(Session)中建立多个线程,并行执行。在Session中的全部线程都必须能被同步终止,异常必须能被正确捕获并报告,会话终止的时候, 队列必须能被正确地关闭。