首先介绍数据读取问题,如今TensorFlow官方推荐的数据读取方法是使用tf.data.Dataset,具体的细节不在这里赘述,看官方文档更清楚,这里主要记录一下官方文档没有提到的坑,以示"后人"。由于是记录踩过的坑,因此行文混乱,见谅。html
不感兴趣的可跳过此节。python
最近在研究ENAS的代码,这个网络的做用是基于加强学习,可以自动生成合适的网络结构。原做者使用TensorFlow在cifar10上成功自动生成了网络结构,并取得了不错的效果。git
但问题来了,此时我须要将代码转移到本身的数据集上,都知道cifar10图像大小是32*32,并非特别大,因此原做者"丧心病狂"
地采用了一次性将数据读进显存的操做,丝绝不考虑我等渣渣的感觉。个人数据集原图基本在500*800或以上,通过反复试验,若是采用源代码我必须将图像经过缩放和中心裁剪到160*160才能正常运行,并且运行结果并非很理想,十分类跑了一天左右最好的结果才30%左右。github
我在想若是把图片放大后是否会提升准确度,因此第一个坑是修改数据读取方式,适应大数据集读取。网络
再仔细阅读源代码后我还发现做者使用了tf.train.shuffle_batch
这个函数用来批量读取,这个函数也让我头疼了好久,由于一直不知道它和tf.data.Dataset.batch.shuffle()
有什么区别,因此第二个坑时tf.train.shuffle_batch
和tf.data.Dataset.batch.shuffle()
到底什么关系(区别)ide
tf.train.batch
和tf.data.Dataset.batch.shuffle()
什么区别其实这两个谈不上什么区别,由于后者是前者的升级版,233333。函数
官方文档对tf.train.batch
的描述是这样的:学习
THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.batch(batch_size) (or padded_batch(...) if dynamic_pad=True).大数据
在这里我也推荐你们用tf.data,由于他相比于原来的tf.train.batch好用太多。ui
这里的大数据集指的是稍微比较大的,像ImageNet这样的数据集还没尝试过。因此下面的方法不敢确定是否使用于ImageNet。
要想读取大数据集,我找到的官方给出的方案有两种:
个人数据集是以已经分好类的文件夹进行存储的,大体结构是这样的
├───test │ ├───Acne_Vulgaris │ ├───Actinic_solar_Damage__Actinic_Keratosis │ ├───Basal_Cell_Carcinoma │ ├───Rosacea │ └───Seborrheic_Keratosis ├───train │ ├───Acne_Vulgaris │ ├───Actinic_solar_Damage__Actinic_Keratosis │ ├───Basal_Cell_Carcinoma │ ├───Rosacea │ └───Seborrheic_Keratosis └───valid ├───Acne_Vulgaris ├───Actinic_solar_Damage__Actinic_Keratosis ├───Basal_Cell_Carcinoma ├───Rosacea └───Seborrheic_Keratosis
个人方法很是适合懒人,具体流程以下:
pytorch提供了torchvision这个库,这个库堪称瑰宝,torchvision.datasets里有个函数是ImageFolder,你只须要指明路径便可把图片数据都读进来,不用再苦逼地手写for循环遍历了。其余的细节,好比data augmentation等等就不介绍了,具体代码可参看官方文档以及以下连接: https://github.com/marsggbo/enas/blob/master/src/skin5_placeholder/data_utils.py
假设上一步已经图像数据读取完毕,并保存成numpy文件,下面参看官方文档例子
# 读取numpy数据 with np.load("/var/data/training_data.npy") as data: features = data["features"] labels = data["labels"] # 查看图像和标签维度是否保持一致 assert features.shape[0] == labels.shape[0] # 建立placeholder features_placeholder = tf.placeholder(features.dtype, features.shape) labels_placeholder = tf.placeholder(labels.dtype, labels.shape) # 建立dataset dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder)) # 批量读取,打散数据,repeat() dataset = dataset.shuffle(20).batch(5).repeat() # [Other transformations on `dataset`...] dataset_other = ... iterator = dataset.make_initializable_iterator() data_element = iterator.get_nex() sess = tf.Session() sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) for e in range(EPOCHS): for step in range(num_batches): x_batch, y_batch = sess.run(data_element) y_pred = model(x_batch) ... ... sess.close()
插播一条广告:上面代码中batch(), shuffle(), repeat()的具体用法参见Tensorflow datasets.shuffle repeat batch方法。
上面逻辑很清楚:
注意,每次一运行sess.run(data_element)这个语句,TensorFlow会自动的调取下一个批次的数据。不只如此,只要sess.run一个把data_element做为输入的节点,也都会自动调取下一个批次的数据。说的有点绕,看例子就明白了
能够看到若是在读取数据的时候还sess.run与数据有关的操做,那么有的数据就根本没遍历到,因此这个问题要特别注意。
那我为何会连这种坑都能踩到呢,由于原做者的代码写的太“好”了,对于我这种刚入门的人来讲太难理解和修改了。
原做者的代码结构并无写for循环遍历读取数据,而后传入到模型。相反他把数据操做写到了另外一个类(文件)中,好比说在model.py
中他定义了
class Model(): def __init__(): ... def _model(self, img, label): y_pred = other_function(img) acc = calculate_acc(y_pred, label) ...
而后在main.py
中他只是sess.run(model.acc),即
with tf.Session() as sess: ... while epoch < EPOCHS: global_step = sess.run(model.global_step) if global_step % 50: acc = sess.run(model.acc) ... ...
抱怨一下: 它这代码结构写得和官方文档不同,因此一直不知道怎么修改。你若是从最开始看到这,你应该以为很好改啊,可是你看着官方文档真不知道怎么修改,由于最开始我并不知道每次sess.run以后都会自动调用下一个batch的数据,并且也尚未习惯TensorFlow数据流的思惟。在这里特别感谢这个问题帮助我解答了困惑:Tensorflow: create minibatch from numpy array > 2 GB。
因此这种状况怎么读取数据呢?很简单,只须要在循环语句以前初始化迭代器便可。
ops = { "global_step": model.global_step, "acc": model.acc } with tf.Session() as sess: ... sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) while epoch < EPOCHS: global_step = sess.run(ops['global_step']) if global_step % 50: acc = sess.run(ops['acc']) ... ...
若是你想要查看数据是否正确读取,千万不要在上面的while循环中加入这么一行代码x_batch, y_batch=sess.run([model.x_batch, model.y_batch])
,这样就会致使上面所说的数据没法完整遍历的问题。那怎么办呢?
咱们能够考虑修改ops
来获取数据,代码以下:
ops = { "global_step": model.global_step, "acc": model.acc, "x_batch": model.x_batch, "y_batch": model.y_batch } with tf.Session() as sess: ... sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels}) while epoch < EPOCHS: global_step = sess.run(ops['global_step']) if global_step % 50: acc = sess.run([ops["acc"], ops["x_batch"], ops["y_batch"]]) ...
这样之因此能完整遍历,是由于咱们将x_batch和acc放在一块儿啦~,因此这能够当作只是一个运算。