Pytorch划分数据集的方法

以前用过sklearn提供的划分数据集的函数,以为超级方便。可是在使用TensorFlow和Pytorch的时候一直找不到相似的功能,以前搜索的关键字都是“pytorch split dataset”之类的,可是搜出来仍是没有我想要的。结果今天见鬼了忽然看见了这么一个函数torch.utils.data.Subset。个人天,为何超级开心hhhh。终于不用每次都手动划分数据集了。html

torch.utils.data

Pytorch提供的对数据集进行操做的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSamplerpython

torch的这个文件包含了一些关于数据集处理的类:微信

  • class torch.utils.data.Dataset: 一个抽象类, 全部其余类的数据集类都应该是它的子类。并且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
  • class torch.utils.data.TensorDataset: 封装成tensor的数据集,每个样本都经过索引张量来得到。
  • class torch.utils.data.ConcatDataset: 链接不一样的数据集以构成更大的新数据集。
  • class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。
  • class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
  • torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分红没有重叠的新数据集组合。
  • class torch.utils.data.Sampler(data_source):全部采样的器的基类。每一个采样器子类都须要提供 iter 方-法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。
  • class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。
  • class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。
  • class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。
  • class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的几率来采样样本。
  • class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其余的采样器。
  • class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器能够约束数据加载进数据集的子集。

示例

下面Pytorch提供的划分数据集的方法以示例的方式给出:dom

SubsetRandomSampler

...

dataset = MyCustomDataset(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed= 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

# Usage Example:
num_epochs = 10
for epoch in range(num_epochs):
    # Train:   
    for batch_index, (faces, labels) in enumerate(train_loader):
        # ...

random_split

...

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

参考:机器学习

<b style="color:tomato;"></b>函数

<footer style="color:white;;background-color:rgb(24,24,24);padding:10px;border-radius:10px;"><br> <h3 style="text-align:center;color:tomato;font-size:16px;" id="autoid-2-0-0"><br> <br> <center> <span>微信公众号:AutoML机器学习</span><br> <img src="https://ask.qcloudimg.com/draft/1215004/21ra82axnz.jpg" style="width:200px;height:200px"> </center> <b>MARSGGBO</b><b style="color:white;"><span style="font-size:25px;">♥</span>原创</b><br> <span>若有意合做或学术讨论欢迎私戳联系~<br>邮箱:marsggbo@foxmail.com</span> <b style="color:white;"><br> 2019-3-8<p></p> </b><p><b style="color:white;"></b><br> </p></h3><br> </footer>学习

相关文章
相关标签/搜索