人人都會機器翻譯系列1 - seq2seq模型 ?
相關鏈接:
github ?(喜歡的話給個星星吧~)
郵箱地址?([email protected])
QQ?(779388649)
歡迎通過各種方式隨時交流
模型發展簡述 ?
seq2seq,全稱爲Sequence to Sequence,即爲傳統的Encoder-Decoder模型,該技術爲深度神經網絡模型最爲經典的案例,突破了輸入序列大小固定的限制,使得經典深度學習模型在機器翻譯、人機交互、自動文摘等領域得到了突破性的進展。
本着化繁爲簡的目的,我們使用PyTorch實現了一個簡易的seq2seq機器翻譯模型,模型只保留了最基礎的Encoder-編碼器、Decoder-解碼器部分,去除掉了許多修飾成分(Attention、Dropout、batch等),方便理解最基礎的Encoder、Decoder工作原理。
Part-I(數據加載部分 - dataLoader.py)
加載數據是神經網絡模型必不可少的一部分,因爲所有的深度學習模型均爲數據驅動的,若沒有足夠的數據支撐,模型很難學習到最優的參數,達到不錯的效果。同時,每一個模型均需要生成對應的數據格式以滿足模型訓練的需求。所以,數據加載部分是至關重要的一部分。
如果你想了解dataLoader.py部分的原始代碼,請登錄github查看源代碼。
並可以通過以下方式運行此程序:
python3 dataLoader.py
Part-II (模型定義部分 - seq2seq.py)
該代碼部分定義了整個程序最核心的部分,機器翻譯模型(encoder&decoder)部分。如下圖所示:
Part-III (模型訓練部分 - train.py)
python3 train.py -h
usage: train.py [-h] --epoch_num EPOCH_NUM [--embedding_size EMBEDDING_SIZE]
[--hidden_size HIDDEN_SIZE] [--model_path MODEL_PATH]
[--srcLang SRCLANG] [--tgtLang TGTLANG]
optional arguments:
-h, --help show this help message and exit
--epoch_num EPOCH_NUM
Number of epoch to train.
--embedding_size EMBEDDING_SIZE
Word Embedding Vector dimension size, default = 300
--hidden_size HIDDEN_SIZE
Hidden size of RNN. default = 300
--model_path MODEL_PATH
The path of encoder and decoder models.
--srcLang SRCLANG The language of source.
--tgtLang TGTLANG The language of target.
使用
CUDA_VISIBLE_DEVICES=3 python3 train.py --epoch_num 1 --embedding_size 300 --hidden_size 300
Part-IV (模型測試部分)
python3 evaluate.py -h
usage: evaluate.py [-h] --encoder ENCODER --decoder DECODER
[--embedding_size EMBEDDING_SIZE]
[--hidden_size HIDDEN_SIZE] [--srcLang SRCLANG]
[--tgtLang TGTLANG]
optional arguments:
-h, --help show this help message and exit
--encoder ENCODER Encoder file path to load trained_encoder's learned
parameters.
--decoder DECODER Decoder file path to load trained_decoder's learned
parameters.
--embedding_size EMBEDDING_SIZE
Word embedding vector dimension size. default = 300
--hidden_size HIDDEN_SIZE
Hidden size of rnn. default = 300
--srcLang SRCLANG The language of source.
--tgtLang TGTLANG The language of target.
使用
CUDA_CISIBLE_DEVICES=3 python3 evaluate.py --encoder model/encoder.pth --decoder model/decoder.pth