RuntimeError: Error(s) in loading state_dict in XXX model: - Missing key(s) in state_dict:"..."

Describe

pytorch 上,训练 DataParrallel 模型,保存模型时使用 torch.save(model.state_dict(), path),

再次导入模型时:

model = Model()
model.load_state_dict(torch.load(path))

此时报错:

RuntimeError: Error(s) in loading state_dict in XXX model:
    Missing key(s) in state_dict:"...
    ...
    ..."

Solution

解决方法很简单,在使用 load_state_dict 前,将 model 转换成 DataParrallel 类型:

model = Model()
model = torch.nn.DataParrallel(model, device_ids=devices).cuda()
model.load_state_dict(torch.load(path))

这是要和保存的 state_dict 的类型对应,如果训练的模型是普通模块 nn.Module 的子类,那么就不需要转换类型。