pytorch错误解决:Missing key(s) in state_dict: Unexpected key(s) in state_dict:

在进行模型测试时报错:测试

Missing key(s) in state_dict: xxxxxxxxxx.net

Unexpected key(s) in state_dict:xxxxxxxxxxcode

 

报错缘由:blog

在模型训练时有加上:【能够加速训练速度】get

model = nn.DataParallel(model)博客

#cudnn.benchmark = Trueio

可是在模型测试推断时,在模型参数被加载到模型前没有加这句话,故报出上面的错误。class

 

解决:model

在模型参数被加载到模型前加下面的语句:map

model = nn.DataParallel(model)

#cudnn.benchmark = True

 

同时在加载模型的时候,添加一个False

 

weight = os.path.join(weight_path)
chkpt = torch.load(weight, map_location=self.__device)
self.__model.load_state_dict(chkpt, False)

 

参考博客:

https://blog.csdn.net/u013925378/article/details/104749000/

https://blog.csdn.net/sinat_34054843/article/details/88046041

相关文章
相关标签/搜索