pytorch導入預訓練模型部分參數

        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint["epoch"] + 1
        #best_score = checkpoint["best_score"]
        best_score = 0

        print("\t* Training will continue on existing model from epoch {}..."
              .format(start_epoch))

        model_now_dict = model.state_dict()
        print(model_now_dict.keys())

        load_pretrained_dict = (checkpoint["model"])
        print(load_pretrained_dict.keys())
        new_state_dict = {k: v for k, v in load_pretrained_dict.items() if k!="_word_embedding.weight"}
        #new_state_dict = {k: v for k, v in load_pretrained_dict.items()}


        # 1. filter out unnecessary keys
        # 2. overwrite entries in the existing state dict
        model_now_dict.update(new_state_dict)
        model.load_state_dict(model_now_dict)

經過了好幾個關卡。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章