PyTorch: Wrapping the DataLoader

This is based on an example from the official torch.nn introductory tutorial.

Usually, we subclass Dataset (from torch.utils.data import Dataset, DataLoader) and override its data-reading methods to fit our own data. Alternatively, TensorDataset offers a more concise way to load tensor data.

torchvision's ImageFolder can also manage image data. Any of these approaches is enough to read data; what DataLoader adds on top is more complete management of batched data:
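A minimal sketch of the first two approaches, using random tensors in place of real MNIST data (the dataset contents and sizes here are illustrative assumptions):

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Stand-in data: 100 flattened 28x28 images with integer class labels.
x = torch.randn(100, 784)
y = torch.randint(0, 10, (100,))

# Option 1: subclass Dataset and implement __len__ / __getitem__.
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Option 2: TensorDataset wraps the tensors directly.
train_ds = TensorDataset(x, y)

# DataLoader layers batching and shuffling on top of either dataset.
loader = DataLoader(MyDataset(x, y), batch_size=32, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape)  # torch.Size([32, 784])
```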


Now to the main point: to feed MNIST into a CNN, the data must be converted to 2-D, so the initially flat vectors need reshaping. One common approach is to read the flat data and reshape it with view inside the model:

import torch.nn as nn


class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)


def preprocess(x):
    return x.view(-1, 1, 28, 28)


model = nn.Sequential(
    Lambda(preprocess),
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AvgPool2d(4),
    Lambda(lambda x: x.view(x.size(0), -1)),
)
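As a quick shape check (a self-contained sketch, with a random batch standing in for MNIST), the model accepts flat 784-dimensional inputs and the leading Lambda reshapes them internally; each stride-2 conv halves the spatial size (28 → 14 → 7 → 4) before pooling:

```python
import torch
import torch.nn as nn

class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func
    def forward(self, x):
        return self.func(x)

model = nn.Sequential(
    Lambda(lambda x: x.view(-1, 1, 28, 28)),               # 784 -> 1x28x28
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 28 -> 14
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1), # 14 -> 7
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1), # 7 -> 4
    nn.ReLU(),
    nn.AvgPool2d(4),                                       # 4 -> 1
    Lambda(lambda x: x.view(x.size(0), -1)),               # flatten to (N, 10)
)

xb = torch.randn(64, 784)  # random batch in place of real MNIST data
print(model(xb).shape)     # torch.Size([64, 10])
```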

The tutorial gives another solution: wrap the DataLoader, moving the preprocessing into its generator:

from torch.utils.data import DataLoader


def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )


def preprocess(x, y):
    return x.view(-1, 1, 28, 28), y


class WrappedDataLoader:
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        batches = iter(self.dl)
        for b in batches:
            yield self.func(*b)


train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
train_dl = WrappedDataLoader(train_dl, preprocess)
valid_dl = WrappedDataLoader(valid_dl, preprocess)
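A runnable sketch of the wrapper in action, with a hypothetical TensorDataset of random tensors in place of the real train_ds; each batch is reshaped on the fly as it is yielded:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def preprocess(x, y):
    return x.view(-1, 1, 28, 28), y

class WrappedDataLoader:
    def __init__(self, dl, func):
        self.dl = dl
        self.func = func

    def __len__(self):
        return len(self.dl)

    def __iter__(self):
        for b in self.dl:
            yield self.func(*b)

# Stand-in dataset: 100 flattened 28x28 images.
train_ds = TensorDataset(torch.randn(100, 784), torch.randint(0, 10, (100,)))
train_dl = WrappedDataLoader(
    DataLoader(train_ds, batch_size=25, shuffle=True), preprocess
)

xb, yb = next(iter(train_dl))
print(xb.shape)  # torch.Size([25, 1, 28, 28]) -- reshaped by the wrapper
```

The model itself no longer needs any knowledge of the flat storage format; the reshaping lives entirely in the data pipeline.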

The model can then be written like this:

model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    Lambda(lambda x: x.view(x.size(0), -1)),
)
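With preprocessing moved into the loader, the model consumes 4-d batches directly; a self-contained sketch with a random batch (Lambda is the same helper defined earlier). Note that nn.AdaptiveAvgPool2d(1) pools whatever spatial size arrives down to 1x1, so this version does not depend on a 28x28 input:

```python
import torch
import torch.nn as nn

class Lambda(nn.Module):  # same helper as above
    def __init__(self, func):
        super().__init__()
        self.func = func
    def forward(self, x):
        return self.func(x)

model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),  # pools any spatial size down to 1x1
    Lambda(lambda x: x.view(x.size(0), -1)),
)

xb = torch.randn(64, 1, 28, 28)  # already reshaped by WrappedDataLoader
print(model(xb).shape)  # torch.Size([64, 10])
```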