I never fully understood how to use DataParallel in PyTorch 1.0 and kept running into all kinds of puzzling problems, so today let's sort it out from the source and get a better grasp of how it works, or rather, of the steps it goes through.
Source code: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
First, let's walk through, line by line, how DataParallel is initialized.
`super` simply calls the constructor of the `torch.nn.Module` parent class, so there is nothing to explain there. `output_device` specifies which GPU the output is collected on; by default it is the first GPU. Note that "first" means the first entry of the `device_ids` list, so if you have three GPUs but move the model to CUDA with `model.cuda(1)` or `model.cuda(2)`, an error will be raised, because `device_ids` is `[0, 1, 2]` and its first element is 0. You can see this check later in the `forward` function.

```python
def __init__(self, module, device_ids=None, output_device=None, dim=0):
    super(DataParallel, self).__init__()

    if not torch.cuda.is_available():
        self.module = module
        self.device_ids = []
        return

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    if output_device is None:
        output_device = device_ids[0]

    self.dim = dim
    self.module = module
    self.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))
    self.output_device = _get_device_index(output_device, True)
    self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))

    _check_balance(self.device_ids)

    if len(self.device_ids) == 1:
        self.module.cuda(device_ids[0])
```
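To make the `device_ids[0]` point concrete, here is a minimal sketch of my own (it assumes a machine with at least three visible GPUs and uses a throwaway `nn.Linear` module that does not appear in the original post):

```python
import torch
import torch.nn as nn

# Toy module, purely for illustration.
model = nn.Linear(10, 5)

# Default construction: device_ids = [0, 1, ..., torch.cuda.device_count() - 1],
# output_device = device_ids[0] = 0, so the wrapped module must live on cuda:0.
dp = nn.DataParallel(model.cuda(0))   # OK

# If the module were moved to cuda:1 while device_ids stayed [0, 1, 2],
# forward() would raise:
#   RuntimeError: module must have its parameters and buffers on device cuda:0 ...
#
# To use GPUs 1 and 2 instead, list them explicitly so that device_ids[0] == 1,
# and place the module on cuda:1:
# dp_subset = nn.DataParallel(nn.Linear(10, 5).cuda(1), device_ids=[1, 2])
```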
Now for the main event: DataParallel's `forward` function.
```python
def forward(self, *inputs, **kwargs):
    if not self.device_ids:
        return self.module(*inputs, **kwargs)

    for t in chain(self.module.parameters(), self.module.buffers()):
        if t.device != self.src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(
                                   self.src_device_obj, t.device))

    inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
    if len(self.device_ids) == 1:
        return self.module(*inputs[0], **kwargs[0])
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    outputs = self.parallel_apply(replicas, inputs, kwargs)
    return self.gather(outputs, self.output_device)
```
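Before digging into the individual helpers, here is an end-to-end usage sketch (the module and shapes are my own assumptions, not from the post): the input batch sits on `device_ids[0]`, `forward` scatters it along dim 0, and the gathered result comes back on `output_device`.

```python
import torch
import torch.nn as nn

# Assumes at least two visible GPUs; the module and sizes are illustrative only.
model = nn.DataParallel(nn.Linear(10, 5).cuda(0))

x = torch.randn(64, 10, device="cuda:0")   # batch of 64 on device_ids[0]
y = model(x)                               # scatter -> replicate -> parallel_apply -> gather

print(y.shape)    # torch.Size([64, 5])
print(y.device)   # cuda:0, i.e. output_device
```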
The `scatter` function slices tensors into approximately equal chunks and distributes them across the given GPUs; objects that are not tensors are simply duplicated by reference:

```python
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
```
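To watch the slicing in isolation, you can call the standalone `torch.nn.parallel.scatter` helper directly. This is a small experiment of my own, assuming two visible GPUs and an arbitrary batch size of 64:

```python
import torch
from torch.nn.parallel import scatter

x = torch.randn(64, 10)   # CPU tensor; scatter copies the chunks onto the GPUs

chunks = scatter(x, target_gpus=[0, 1], dim=0)
for c in chunks:
    print(c.shape, c.device)
# Expected output, roughly:
#   torch.Size([32, 10]) cuda:0
#   torch.Size([32, 10]) cuda:1
```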
The `replicate` function is fairly complex, so I won't explain it here; if you are interested, read the source: https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/replicate.py . Its main job is to copy the model onto each of the GPUs.

`parallel_apply` then runs the model on the multiple GPUs in parallel. Every replica is identical; only the input data differs, because it was split evenly beforehand. For example, with two GPUs and a batch size of 64, each GPU processes a batch of 32.

Finally, `gather` collects the outputs from all GPUs and transfers them to `output_device`, i.e. the first GPU.
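Putting the pieces together, `forward` is essentially the four standalone helpers from `torch.nn.parallel` chained by hand. The sketch below is my own reconstruction, assuming two visible GPUs and a toy `nn.Linear`; it is not code from the PyTorch source:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import scatter, replicate, parallel_apply, gather

device_ids = [0, 1]
output_device = device_ids[0]

module = nn.Linear(10, 5).cuda(device_ids[0])   # module lives on device_ids[0]
x = torch.randn(64, 10, device="cuda:0")        # full batch of 64

inputs = scatter(x, device_ids, dim=0)                    # two chunks of 32 samples
replicas = replicate(module, device_ids[:len(inputs)])    # one copy of the module per GPU
outputs = parallel_apply(replicas,
                         [(chunk,) for chunk in inputs],  # one args tuple per replica
                         devices=device_ids[:len(inputs)])
y = gather(outputs, output_device, dim=0)                 # (64, 5), back on cuda:0
```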