RuntimeError: Error(s) in loading state_dict in XXX model: - Missing key(s) in state_dict:"..."
Describe
pytorch 上,训练 DataParrallel
模型,保存模型时使用 torch.save(model.state_dict(), path)
,
再次导入模型时:
model = Model()
model.load_state_dict(torch.load(path))
此时报错:
RuntimeError: Error(s) in loading state_dict in XXX model:
Missing key(s) in state_dict:"...
...
..."
Solution
解决方法很简单,在使用 load_state_dict
前,将 model
转换成 DataParrallel
类型:
model = Model()
model = torch.nn.DataParrallel(model, device_ids=devices).cuda()
model.load_state_dict(torch.load(path))
这是要和保存的 state_dict
的类型对应,如果训练的模型是普通模块 nn.Module
的子类,那么就不需要转换类型。