# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os
def load_CIFAR_batch(filename):
""" 载入cifar数据集的一个batch """
with open(filename, 'r') as f:
datadict = p.load(f)
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y
def load_CIFAR10(ROOT):
""" 载入cifar所有数据 """
xs = []
ys = []
for b in range(1, 6):
f = os.path.join(ROOT, 'data_batch_%d' % (b,))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr = np.concatenate(xs)
Ytr = np.concatenate(ys)
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
return Xtr, Ytr, Xte, Yte
复制代码
错误代码以下:
'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence复制代码
因而乎开始各类搜索问题,问大佬,网上的答案都是相似:
然而并无解决问题!仍是错误的!(我大概搜索了一下午吧,都是上面的答案)数据库
哇,就当我很绝望的时候,我终于发现了一个新奇的答案,抱着试一试的态度,尝试了一下:def load_CIFAR_batch(filename):
""" 载入cifar数据集的一个batch """
with open(filename, 'rb') as f:
datadict = p.load(f, encoding='latin1')
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y复制代码
居然成功了,这里没有报错了!欣喜之余,我就很好奇,encoding='latin1'究竟是啥玩意呢,之前没有见过啊?因而,我搜索了一下,了解到:
Latin1是ISO-8859-1的别名,有些环境下写做Latin-1。ISO-8859-1编码是单字节编码,向下兼容ASCII,其编码范围是0x00-0xFF,0x00-0x7F之间彻底和ASCII一致,0x80-0x9F之间是控制字符,0xA0-0xFF之间是文字符号。还没等我高兴起来,运行后,又发现了一个问题:
由于ISO-8859-1编码范围使用了单字节内的全部空间,在支持ISO-8859-1的系统中传输和存储其余任何编码的字节流都不会被抛弃。换言之,把其余任何编码的字节流看成ISO-8859-1编码看待都没有问题。这是个很重要的特性,MySQL数据库默认编码是Latin1就是利用了这个特性。ASCII编码是一个7位的容器,ISO-8859-1编码是一个8位的容器。
memory error复制代码
什么鬼?内存错误!哇,原来是数据大小的问题。
X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")复制代码
这告诉咱们每批数据都是10000 * 3 * 32 * 32,至关于超过3000万个浮点数。 float数据类型实际上与float64相同,意味着每一个数字大小占8个字节。这意味着每一个批次占用至少240 MB。你加载6这些(5训练+ 1测试)在总产量接近1.4 GB的数据。
for b in range(1, 2):
f = os.path.join(ROOT, 'data_batch_%d' % (b,))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
复制代码
# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os
def load_CIFAR_batch(filename):
""" 载入cifar数据集的一个batch """
with open(filename, 'rb') as f:
datadict = p.load(f, encoding='latin1')
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y
def load_CIFAR10(ROOT):
""" 载入cifar所有数据 """
xs = []
ys = []
for b in range(1, 2):
f = os.path.join(ROOT, 'data_batch_%d' % (b,))
X, Y = load_CIFAR_batch(f)
xs.append(X) #将全部batch整合起来
ys.append(Y)
Xtr = np.concatenate(xs) #使变成行向量,最终Xtr的尺寸为(50000,32,32,3)
Ytr = np.concatenate(ys)
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
return Xtr, Ytr, Xte, Yte
import numpy as np
from julyedu.data_utils import load_CIFAR10
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
# 载入CIFAR-10数据集
cifar10_dir = 'julyedu/datasets/cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
# 看看数据集中的一些样本:每一个类别展现一些
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)
复制代码
顺便看一下CIFAR-10数据组成:
更多内容,可关注个人我的公众号
bash