使用多gpu訓練時
model = torch.nn.DataParallel(model, device_ids=[1, 2, 3, 4])
若模型採用多GPU訓練,則在模型保存時:
torch.save(model.module.state_dict(), model_out_path)
若單GPU則:
torch.save(mode.state_dict(), model_out_path)
使用多gpu訓練時
model = torch.nn.DataParallel(model, device_ids=[1, 2, 3, 4])
若模型採用多GPU訓練,則在模型保存時:
torch.save(model.module.state_dict(), model_out_path)
若單GPU則:
torch.save(mode.state_dict(), model_out_path)