机器学习中数据读取是很重要的一个环节,TensorFlow也提供了不少实用的方法,为了不之后时间久了又忘记,因此写下笔记以备往后查看。python
首先咱们看看最普通的状况:session
# 建立0-10的数据集,每一个batch取个数。 dataset = tf.data.Dataset.range(10).batch(6) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(2): value = sess.run(next_element) print(value)
输出结果机器学习
[0 1 2 3 4 5] [6 7 8 9]
由结果咱们能够知道TensorFlow能很好地帮咱们自动处理最后一个batch的数据。学习
可是若是上面for循环次数超过2会怎么样呢?也就是说若是 **循环次数*批数量 > 数据集数量** 会怎么样?咱们试试看:spa
dataset = tf.data.Dataset.range(10).batch(6) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: >>==for i in range(3):==<< value = sess.run(next_element) print(value)
输出结果code
[0 1 2 3 4 5] [6 7 8 9] --------------------------------------------------------------------------- OutOfRangeError Traceback (most recent call last) D:\Continuum\anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args) 1277 try: ... ...省略若干信息... ... OutOfRangeError (see above for traceback): End of sequence [[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]
能够知道超过范围了,因此报错了。element
为了解决上述问题,repeat方法登场。仍是直接看例子吧:资源
dataset = tf.data.Dataset.range(10).batch(6) dataset = dataset.repeat(2) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(4): value = sess.run(next_element) print(value)
输出结果get
[0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9]
能够知道repeat其实就是将数据集重复了指定次数,上面代码将数据集重复了2次,因此此次即便for循环次数是4也依旧能正常读取数据,而且都能完整把数据读取出来。同理,若是把for循环次数设置为大于4,那么也仍是会报错,这么一来,我每次还得算repeat的次数,岂不是很心累?因此更简便的办法就是对repeat方法不设置重复次数,效果见以下:it
dataset = tf.data.Dataset.range(10).batch(6) dataset = dataset.repeat() iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(6): value = sess.run(next_element) print(value)
输出结果:
[0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9]
此时不管for循环多少次都不怕啦~~
仔细看能够知道上面全部输出结果都是有序的,这在机器学习中用来训练模型是浪费资源且没有意义的,因此咱们须要将数据打乱,这样每批次训练的时候所用到的数据集是不同的,这样啊能够提升模型训练效果。
另外shuffle前须要设置buffer_size:
dataset = tf.data.Dataset.range(10).shuffle(2).batch(6) dataset = dataset.repeat(2) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(4): value = sess.run(next_element) print(value)
输出结果:
[1 0 2 4 3 5] [7 8 9 6] [1 2 3 4 0 6] [7 8 9 5]
注意:shuffle的顺序很重要,通常建议是最开始执行shuffle操做,由于若是是先执行batch操做的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。不信你看:
dataset = tf.data.Dataset.range(10).batch(6).shuffle(10) dataset = dataset.repeat(2) iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() with tf.Session() as sess: for i in range(4): value = sess.run(next_element) print(value)
输出结果:
[0 1 2 3 4 5] [6 7 8 9] [0 1 2 3 4 5] [6 7 8 9]