『PyTorch』模型的保存與加載

序列化與反序列化

序列化

torch.save(obj, f)
  • 主要參數:

    • obj:對象
    • f:輸出路徑
  • 例如

    • 保存整個模型
      torch.save(net, path)
    • 保存模型參數
      state_dict=net.state_dict()
      torch.save(state_dict, path)

反序列化

torch.load(f, map_location=None)
  • 主要參數:
    • f:文件路徑
    • map_location:指定存放位置,cpu or gpu
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章