torch多GPU模型的訓練與保存

使用多gpu訓練時

model = torch.nn.DataParallel(model, device_ids=[1, 2, 3, 4])

若模型採用多GPU訓練,則在模型保存時:

torch.save(model.module.state_dict(), model_out_path)

若單GPU則:

torch.save(mode.state_dict(), model_out_path)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章