工具代碼粘貼3——pytorch載入模型

    print('loading checkpoint.......')
    model_dict = model.state_dict()
    pretrained_dict = torch.load(weight_path)
    pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print('loaded successfully!')

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章