Error when loading a BERT model with PyTorch

Background

Loading the albert-base-chinese model downloaded from Hugging Face with PyTorch fails with the following error:

Exception has occurred: OSError
Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.

Model page: https://huggingface.co/models?search=albert_chinese

Approach 1:

Deleting the cache directory, as suggested in the following posts, did not fix the problem:
https://blog.csdn.net/znsoft/article/details/107725285
https://github.com/huggingface/transformers/issues/6159

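Approach 2:

Loading the same model on another computer succeeded. Comparing the torch and transformers versions on the two machines showed that one had torch 1.1 and the other torch 1.7.x.
According to the PyTorch documentation, torch 1.6 changed the format used to save models, so a model saved by a newer version cannot be loaded by an older one: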

The 1.6 release of PyTorch switched torch.save to use a new zipfile-based file format. torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.
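
Since the quote describes the new format as zipfile-based, one quick way to tell which format a checkpoint file uses is to test whether it is a zip archive. The following is a minimal sketch using only the Python standard library; checkpoint_format is a hypothetical helper name, and the demo files are placeholders, not real checkpoints.

```python
import os
import tempfile
import zipfile


def checkpoint_format(path):
    """Return 'zipfile' if the file is a zip archive (torch >= 1.6 default), else 'legacy'."""
    return "zipfile" if zipfile.is_zipfile(path) else "legacy"


# Demo with stand-in files (not real checkpoints): one zip archive, one raw byte file.
tmp_dir = tempfile.mkdtemp()

zip_path = os.path.join(tmp_dir, "new_style.bin")
with zipfile.ZipFile(zip_path, "w") as zf:
    zf.writestr("archive/data.pkl", b"placeholder")

raw_path = os.path.join(tmp_dir, "old_style.bin")
with open(raw_path, "wb") as f:
    f.write(b"\x80\x02placeholder")

print(checkpoint_format(zip_path))
print(checkpoint_format(raw_path))
```

Running checkpoint_format on the downloaded pytorch_model.bin on the machine that fails would show whether the file is in the newer format, which an old torch version cannot read.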

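Solution:

  1. Upgrade torch to a newer version.
  2. If upgrading is not possible (for example, because of CUDA compatibility), load the model on the newer version, then save it again with _use_new_zipfile_serialization=False: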
import torch
from transformers import AlbertForMaskedLM, BertTokenizer

# Local directory holding the downloaded model files
pretrained = 'D:/07_data/albert_base_chinese'
tokenizer = BertTokenizer.from_pretrained(pretrained)
model = AlbertForMaskedLM.from_pretrained(pretrained)

# Unwrap the model if it is wrapped in PyTorch DistributedDataParallel or DataParallel
model_to_save = model.module if hasattr(model, 'module') else model

# Save the weights in the legacy (pre-1.6) format so older torch versions can load them
torch.save(model_to_save.state_dict(), 'pytorch_model_unzip.bin', _use_new_zipfile_serialization=False)
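
The re-saved file still has to be placed where the older machine will find it. The sketch below assembles a model directory around the re-saved weights. It assumes that from_pretrained looks for a weights file named pytorch_model.bin inside the model directory, which is worth verifying for the transformers version in use. build_legacy_model_dir, its arguments, and the albert_base_chinese_legacy directory name are hypothetical, and the demo uses placeholder files rather than a real model.

```python
import os
import shutil
import tempfile

WEIGHTS_NAME = "pytorch_model.bin"  # assumed file name that from_pretrained looks for


def build_legacy_model_dir(src_dir, resaved_weights, dst_dir):
    """Copy config/vocab files from src_dir and use resaved_weights as the weights file."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        src_path = os.path.join(src_dir, name)
        # Skip the original (zipfile-format) weights and any subdirectories
        if name == WEIGHTS_NAME or not os.path.isfile(src_path):
            continue
        shutil.copy2(src_path, os.path.join(dst_dir, name))
    shutil.copy2(resaved_weights, os.path.join(dst_dir, WEIGHTS_NAME))
    return dst_dir


# Demo with placeholder files standing in for a real model directory
work = tempfile.mkdtemp()
src = os.path.join(work, "albert_base_chinese")
os.makedirs(src)
for name, data in [("config.json", b"{}"), ("vocab.txt", b"[PAD]\n"), (WEIGHTS_NAME, b"new-format")]:
    with open(os.path.join(src, name), "wb") as f:
        f.write(data)

resaved = os.path.join(work, "pytorch_model_unzip.bin")
with open(resaved, "wb") as f:
    f.write(b"legacy-format")

legacy_dir = build_legacy_model_dir(src, resaved, os.path.join(work, "albert_base_chinese_legacy"))
print(sorted(os.listdir(legacy_dir)))
```

The resulting directory can then be copied to the machine with the older torch and passed to from_pretrained in place of the original model path.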