Some of the most commonly used datasets, such as MNIST, Fashion MNIST, and CIFAR-10/100, can be found in tf.keras.datasets. Other frequently used datasets, such as SVHN and Caltech101, are not included there; for those, we can look in TensorFlow Datasets.
The list of datasets included in tensorflow_datasets: https://www.tensorflow.org/datasets/catalog/overview#all_datasets
Installing tensorflow_datasets: pip install tensorflow_datasets
Getting a tf.data.Dataset object:
```python
import tensorflow as tf
import tensorflow_datasets as tfds

data, info = tfds.load("mnist", with_info=True)
print(info)
train_data, test_data = data['train'], data['test']
assert isinstance(train_data, tf.data.Dataset)
print(train_data)
```
Getting numpy.ndarray objects:
```python
import tensorflow_datasets as tfds

# `batch_size=-1` will return the full dataset as `tf.Tensor`s.
dataset, info = tfds.load("mnist", batch_size=-1, with_info=True)
print(info)
train, test = dataset["train"], dataset["test"]
print(type(train['image']))
train = tfds.as_numpy(train)
print(type(train['image']))
print(train['image'].shape)
print(train['label'].shape)
```
For a simple way to carve a validation set out of a tf.data.Dataset, see https://github.com/tensorflow/datasets/issues/665#issuecomment-502409920
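The idea in that issue thread can be sketched with `take` and `skip`. The sketch below uses a small synthetic `tf.data.Dataset` so it runs without downloading anything; the names `full_train` and `VAL_SIZE` are illustrative (for the real MNIST train split of 60,000 examples, a 10% validation set would be 6,000):

```python
import tensorflow as tf

# Stand-in for a loaded train split; a range dataset keeps the sketch
# runnable without any download.
full_train = tf.data.Dataset.range(100)

VAL_SIZE = 10  # e.g. hold out 10% for validation
val_ds = full_train.take(VAL_SIZE)    # first VAL_SIZE examples
train_ds = full_train.skip(VAL_SIZE)  # everything after them

print(len(list(val_ds)), len(list(train_ds)))  # 10 90
```

Note this is a sequential cut, not a stratified one; shuffle first (with a fixed seed) if the dataset is ordered by class.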
If you want a manual stratified random validation split for MNIST or similar datasets, it is more convenient to convert to numpy.ndarray first; then sklearn's train_test_split does it in one line.
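A minimal sketch of that one-liner, using small stand-in arrays shaped like the MNIST tensors above (the real shapes are (60000, 28, 28, 1) and (60000,)); the `stratify=labels` argument is what keeps class proportions identical across the two splits:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Stand-in data: 600 "images", 60 examples per class for 10 classes.
images = np.zeros((600, 28, 28, 1), dtype=np.uint8)
labels = np.repeat(np.arange(10), 60)

# stratify=labels -> each class contributes the same fraction to both splits
x_train, x_val, y_train, y_val = train_test_split(
    images, labels, test_size=0.1, stratify=labels, random_state=42)

print(x_train.shape, x_val.shape)  # (540, 28, 28, 1) (60, 28, 28, 1)
```

With `test_size=0.1` and 60 examples per class, each class ends up with exactly 6 examples in the validation set.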
https://www.tensorflow.org/datasets
https://www.tensorflow.org/datasets/catalog/overview#all_datasets
https://github.com/tensorflow/datasets/blob/master/docs/splits.md