文本分類_acc分數異常

在照搬別人的參數時候,nn.LSMT中有一個參數,batch_first,對它設置了True,於是分數直接下降了70個點。查閱過之後,發現是nn.LSTM中的batch_first是指它接受輸入時,會將第一維的位置,認爲是batch。爲了驗證寫了一個小例子。
爲了方便觀察,設置了batch爲4,句子最長長度爲11

讀入數據

file_path="E:/study_series/2020_3_data/data/corpus/"
train_iter, valid_iter, test_iter,TEXT=generate_data(file_path)
a=list(train_iter)

如下圖所示,爲讀入的一個批次樣本
在這裏插入圖片描述

LSTM batch_first


import torch 
import torch.nn as nn
import random
random.seed(1)
torch.manual_seed(1)

vocab=4000
emb_dim=768
embedding=nn.Embedding(vocab,emb_dim)
print(embedding)

## 定義LSTM
hidden_dim=64
lstm=nn.LSTM(emb_dim, hidden_dim,batch_first=False)
#lstm=nn.LSTM(emb_dim, hidden_dim)
## 數據過一層embedding
embedding_data=embedding(a[0].context.cpu())
print(embedding_data.shape)
## 上一層的輸入再過lstm
lstm_data,_=lstm(embedding_data)
print(lstm_data.shape)
print(lstm_data[:,-1,:])

代碼調試發現,batch_first的值會對最後結構有很大的不同。猜測是對數據的結構進行了不同的理解。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章