Albert_zh轉化爲pytorch版本

背景
由於google提供bert_base_zh的參數太多,模型大太,大約400M,無論是使用bert進行fine-tuning或者是對bert進行再訓練的訓練成本會變大。所以就對bert進行了一些優化。

主要從以下幾個點對bert進行了優化:

  1. 詞嵌入向量的因式分解
    O(VH)>O(VE+EH)O(V*H)->O(V*E +E*H)
    其中V爲字典中詞的個數,H爲隱藏層size,E是Albert中因式分解的的一個變量。以Albert_xxlarge爲例,V=30000,H=4096,E=128,那麼原來的個數是VH=300004096=1.23億個參數,現在變爲VE+EH = 30000128 + 1284096=436萬,縮小爲原來的1/28.
  2. 跨層參數共享
    參數共享能夠顯著的減小參數。共享參數分爲全連接層、注意力層的參數共享,但是注意力層的參數對效果的減弱影響小一些。
  3. 段落連續任務
    除了bert的mask任務及NSP任務,增加了一個段落連續任務。正例:使用一個文檔中連續的兩個文本段落,負例是使用一個文檔的連續的兩個段落,但是位置調換了。
  4. 去掉了dropout
    發現最大的模型訓練了100萬步後,還是沒有過擬合,說明模型的容量還可以更大一些,就移除了dropout。其實dropout是隨機失活一些節點,本質上還是減小模型。
  5. 爲了加快訓練,使用了LAMB作爲優化器,可以使用大的batch_size。
  6. 使用了n-gram來做mask語音模型。

但是有一個問題如何把tensorflow版本的albert轉化爲Pytorch可以使用的呢?

  1. 下載albert_tiny_google_zh,一定是google版本的,https://storage.googleapis.com/albert_zh/albert_tiny_zh_google.zip
  2. git clone albert_pytorch:https://github.com/lonePatient/albert_pytorch
  3. 使用如下命令進行轉化:
python convert_albert_tf_checkpoint_to_pytorch.py \
    --tf_checkpoint_path=./prev_trained_model/albert_tiny_zh \
    --bert_config_file=./prev_trained_model/albert_tiny_zh/albet_config_tiny_g.json \
    --pytorch_dump_path=./prev_trained_model/albet_tiny_zh/pytorch_model.bin
  1. 使用transformers加載模型並使用

# -*- encoding: utf-8 -*-
import warnings
warnings.filterwarnings('ignore')
from transformers import AlbertModel, BertTokenizer, AutoModel, AutoTokenizer

import os
from os.path import dirname, abspath
import torch
root_dir = dirname(dirname(dirname(abspath(__file__))))

if __name__ == '__main__':
    albert_path = os.path.join(root_dir, 'pretrained/albert_tiny_zh_pytorch')
    # 加載模型
    model = AutoModel.from_pretrained(albert_path)
    # 加載tokenizer,這裏使用BertTokenizer,如果使用AutoTokenizer會報錯。
    tokenizer = BertTokenizer.from_pretrained(albert_path)
    print(model)
    tokens = tokenizer.encode('我愛中國共產黨',add_special_tokens=True)
    print(tokens)
    predict  = model(torch.tensor(tokens).unsqueeze(0))
    print(predict[0].size()) # this is the last hidden_state with size of [batch_size, seq_len, hidden_size]
    print(predict[1].size())
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章