PyTorch 报错:ModuleAttributeError: ‘DataParallel‘ object has no attribute ‘ xxx (已解决)

PyTorch 报错:ModuleAttributeError: 'DataParallel' object has no attribute ' xxx (已解决)

 

这个问题中 ,‘XXX’ 通常就是代码里面的须要优化的模型名称,例如,个人模型里定义了 optimizer_G 和 optimizer_D 两个网络(生成器网络和判别器网络)。python

问题缘由:

在 train.py 中,调用它们时,直觉地写成了 model.optimizer_G 的格式,以下:网络

model = create_model(opt)
model = model.cuda()
visualizer = Visualizer(opt)
if opt.fp16:    
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1')             
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_Dh = model.optimizer_G, model.optimizer_D

然而,其实这时 model 转换成了 model.module。优化

 

解决方法:

在 ‘ model. ’ 后面加一个 ‘ module. ’ 。spa

将 model.optimizer_G 改为 model.module.optimizer_Gcode

将 model.optimizer_D 改为 model.module.optimizer_Dit

model = create_model(opt)
model = model.cuda()
visualizer = Visualizer(opt)
if opt.fp16:    
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.module.optimizer_G, model.module.optimizer_D], opt_level='O1')             
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_Dh = model.module.optimizer_G, model.module.optimizer_D
相关文章
相关标签/搜索