【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经创建,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎你们加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.python
参考目录:
微信
本文的代码已经上传公众号后台,回复【PyTorch】获取。
第一次接触到TFrec文件,我也是比较蒙蔽的其实:
机器学习
能够看到文件是.tfrec
后缀的,并且先记住这个文件是186.72MB大小的。函数
正常状况下咱们用于训练的文件夹内部每每会存着成千上万的图片或文本等文件,这些文件一般被散列存放。这种存储方式有一些缺点:学习
而tfrec格式的文件存储形式会很合理的帮咱们存储数据,核心就是tfrec内部使用Protocol Buffer的二进制数据编码方案,这个方案能够极大的压缩存储空间。编码
以前咱们知道一个tfrec文件100多M,这是由于这个tfrec文件内存储了不少的图片,相似于压缩,对tfrec解压缩后能够获取到一部分的数据集,当咱们把所有的rfrec文件都解压缩后,能够获取到所有的数据集。人工智能
值得一提的是,rfrec文件内除了能够存储图片,还能够存储其余的数据,比方说图片的label。字符串,float类型等均可以转换成二进制的方法,因此什么数据类型基本上均可以存储到rfrec文件内,从而简化读取数据的过程。3d
tfrec文件时tensorflow的数据集存储格式,tensorflow能够高效的读取和处理这些数据集,所以我见过有的数据集由于是tfrec文件,因此用TF读取数据集,而后用pytorch训练模型。code
以前提到了tfrec文件里面是有多个样本的,因此tfrec能够为是多个tf.train.Example
文件组成的序列(每个example是一个样本),而后每个tf.train.Example
又是由若干个tf.train.Features
字典组成。这个Features能够理解为这个样本的一些信息,若是是图片样本,那么确定有一个Features是图片像素值数据,一个Features是图片的标签值;若是是预测任务,那么这个Feature可能就是一些字符串类型的特征对象
import tensorflow as tf import glob # 先记录一下要保存的tfrec文件的名字 tfrecord_file = './train.tfrec' # 获取指定目录的全部以jpeg结尾的文件list images = glob.glob('./*.jpeg') with tf.io.TFRecordWriter(tfrecord_file) as writer: for filename in images: image = open(filename, 'rb').read() # 读取数据集图片到内存,image 为一个 Byte 类型的字符串 feature = { # 创建 tf.train.Feature 字典 'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])), # 图片是一个 Bytes 对象 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])), 'float':tf.train.Feature(float_list=tf.train.FloatList(value=[1.0,2.0])), 'name':tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode(filename)])) } # tf.train.Example 在 tf.train.Features 外面又多了一层封装 example = tf.train.Example(features=tf.train.Features(feature=feature)) # 经过字典创建 Example writer.write(example.SerializeToString()) # 将 Example 序列化并写入 TFRecord 文件
代码中咱们须要注意的地方是:
str.encode
来把字符串转换成字节;这一段代码建议保存下来,方便之后的直接参考和复制。构建tfrec文件对于tensorflow处理图片来讲,应该是绕不过的一个步骤。
如今,咱们运行完上面的代码,应该生成了一个./train.tfrec
文件,下面咱们再对这个文件进行读取。
import tensorflow as tf dataset = tf.data.TFRecordDataset('./train.tfrec') def decode(example): feature_description = { 'image': tf.io.FixedLenFeature([], tf.string), 'label': tf.io.FixedLenFeature([], tf.int64), 'float': tf.io.FixedLenFeature([1, 2], tf.float32), 'name': tf.io.FixedLenFeature([], tf.string) } feature_dict = tf.io.parse_single_example(example, feature_description) feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image']) # 解码 JEPG 图片 return feature_dict dataset = dataset.map(decode).batch(4) for i in dataset.take(1): print(i['image'].shape) print(i['label'].shape) print(i['float'].shape) print(bytes.decode(i['name'][0].numpy()))
tf.data.TFRecordDataset
,进行读取,建立了一个dataset,可是这个dataset并不能直接使用,须要对tfrec中的example进行一些解码;tf.io.parse_single_example
方法,从example中提取到对应的特征;tf.io.decode_jpeg()
来把字符串解码成一个tensor张量。.batch(4)
把数据集每个batch包含四个样本。上面代码输出的结果为:
须要注意的是这个如何把name转换成string类型的,若是已经在本地跑完了上面的代码,能够本身看看i['name']是一个什么类型的,而后本身试试如何转换成string类型的。上面的代码是能成功转换的。
下一次的内容就是如何构建模型,而后怎么把数据集喂给模型。