前面的博客中咱们说过,在加载数据和预处理数据时使用tf.data.Dataset对象将极大将咱们从建模前的数据清理工做中释放出来,那么,怎么将自定义的数据集加载为DataSet对象呢?这对不少新手来讲都是一个难题,由于绝大多数案例教学都是以mnist数据集做为例子讲述如何将数据加载到Dataset中,而英文资料对这方面的介绍隐藏得有点深。本文就来捋一捋如何加载自定义的图片数据集实现图片分类,后续将继续介绍如何加载自定义的text、mongodb等数据。javascript
若是你已有数据集,那么,请将全部数据存放在同一目录下,而后将不一样类别的图片分门别类地存放在不一样的子目录下,目录树以下所示:css
$ tree flower_photos -L 1html
flower_photos ├── daisy ├── dandelion ├── LICENSE.txt ├── roses ├── sunflowers └── tulipshtml5
全部的数据都存放在flower_photos目录下,每个子目录(daisy、dandelion等等)存放的都是一个类别的图片。若是你已有本身的数据集,那就按上面的结构来存放,若是没有,想操做学习一下,你能够经过下面代码下载上述图片数据集:java
import tensorflow as tf import pathlib data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', fname='flower_photos', untar=True) data_root = pathlib.Path(data_root_orig) print(data_root) # 打印出数据集所在目录
下载好后,建议将整个flower_photos目录移动到项目根目录下。node
import tensorflow as tf import random import pathlib data_path = pathlib.Path('./data/flower_photos') all_image_paths = list(data_path.glob('*/*')) all_image_paths = [str(path) for path in all_image_paths] # 全部图片路径的列表 random.shuffle(all_image_paths) # 打散 image_count = len(all_image_paths) image_count
3670
查看一下前5张:python
all_image_paths[:5]
['data/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg', 'data/flower_photos/roses/15222804561_0fde5eb4ae_n.jpg', 'data/flower_photos/daisy/18622672908_eab6dc9140_n.jpg', 'data/flower_photos/roses/459042023_6273adc312_n.jpg', 'data/flower_photos/roses/16149016979_23ef42b642_m.jpg']
读取图片的同时,咱们也不能忘记图片与标签的对应,要建立一个对应的列表来存放图片标签,不过,这里所说的标签不是daisy、dandelion这些具体分类名,而是整型的索引,毕竟在建模的时候y值通常都是整型数据,因此要建立一个字典来创建分类名与标签的对应关系:jquery
label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir()) label_names
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
label_to_index = dict((name, index) for index, name in enumerate(label_names)) label_to_index
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
for image, label in zip(all_image_paths[:5], all_image_labels[:5]): print(image, ' ---> ', label)
data/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg ---> 3 data/flower_photos/roses/15222804561_0fde5eb4ae_n.jpg ---> 2 data/flower_photos/daisy/18622672908_eab6dc9140_n.jpg ---> 0 data/flower_photos/roses/459042023_6273adc312_n.jpg ---> 2 data/flower_photos/roses/16149016979_23ef42b642_m.jpg ---> 2
好了,如今咱们能够建立一个Dataset了:linux
ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
不过,这个ds可不是咱们想要的,毕竟,里面的元素只是图片路径,因此咱们要进一步处理。这个处理包含读取图片、从新设置图片大小、归一化、转换类型等操做,咱们将这些操做通通定义到一个方法里:android
def load_and_preprocess_from_path_label(path, label): image = tf.io.read_file(path) # 读取图片 image = tf.image.decode_jpeg(image, channels=3) image = tf.image.resize(image, [192, 192]) # 原始图片大小为(266, 320, 3),重设为(192, 192) image /= 255.0 # 归一化到[0,1]范围 return image, label
image_label_ds = ds.map(load_and_preprocess_from_path_label)
image_label_ds
<MapDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int32)>
这时候,其实就已经将自定义的图片数据集加载到了Dataset对象中,不过,咱们还能秀,能够继续shuffle随机打散、分割成batch、数据repeat操做。这些操做有几点须要注意: (1)先shuffle、repeat、batch三种操做顺序有讲究:
(2)shuffle操做时,buffer_size越大,打乱效果越好,但消耗内存越大,可能形成延迟。
推荐经过使用 tf.data.Dataset.apply 方法和融合过的 tf.data.experimental.shuffle_and_repeat 函数来执行这些操做:
ds = image_label_ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count)) BATCH_SIZE = 32 ds = ds.batch(BATCH_SIZE)
好了,至此,本文内容其实就结束了,由于已经将自定义的图片数据集加载到了Dataset中。
下面的内容做为扩展阅读。
上面的方法是简单的在每次epoch迭代中单独读取每一个文件,在本地使用 CPU 训练时这个方法是可行的,可是可能不足以进行GPU训练而且彻底不适合任何形式的分布式训练。
可使用tf.data.Dataset.cache在epoch迭代过程间缓存计算结果。这能极大提高程序效率,特别是当内存能容纳所有数据时。
在被预处理以后(解码和调整大小),图片就被缓存了:
ds = image_label_ds.cache() # 缓存 ds = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
使用内存缓存的一个缺点是必须在每次运行时重建缓存,这使得每次启动数据集时有相同的启动延迟。若是内存不够容纳数据,使用一个缓存文件:
ds = image_label_ds.cache(filename='./cache.tf-data') ds = ds.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
https://tensorflow.google.cn/tutorials/load_data/images