PyTorch加載預訓練模型

  1. 加載單GPU模型
model = net()
pretrained_dict = torch.load("abc.pth")
model.load_state_dict(pretrained)
  1. 加載多GPU模型
model = net()
pretrained_dict = toch.load("m_abc.pth")
model.module.load_state_dict() # 多GPU要加module
  1. 加載部分預訓練模型參數
model = net()
pretrained_dict = torch.load("abc.pth")
model_dict = model.state_dict()
# filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章