這是一個造輪子的過程,但是從頭構建LSTM能夠使我們對體系結構進行更加了解,並將我們的研究帶入下一個層次。
LSTM單元是遞歸神經網絡深度學習研究領域中最有趣的結構之一:它不僅使模型能夠從長序列中學習,而且還爲長、短期記憶創建了一個數值抽象,可以在需要時相互替換。
在這篇文章中,我們不僅將介紹LSTM單元的體系結構,還將通過PyTorch手工實現它。
最後但最不重要的是,我們將展示如何對我們的實現做一些小的調整,以實現一些新的想法,這些想法確實出現在LSTM研究領域,如peephole。
LSTM體系結構
LSTM被稱爲門結構:一些數學運算的組合,這些運算使信息流動或從計算圖的那裏保留下來。因此,它能夠“決定”其長期和短期記憶,並輸出對序列數據的可靠預測:
LSTM單元中的預測序列。注意,它不僅會傳遞預測值,而且還會傳遞一個c,c是長期記憶的代表
遺忘門
遺忘門(forget gate)是輸入信息與候選者一起操作的門,作爲長期記憶。請注意,在輸入、隱藏狀態和偏差的第一個線性組合上,應用一個sigmoid函數:
sigmoid將遺忘門的輸出“縮放”到0-1之間,然後,通過將其與候選者相乘,我們可以將其設置爲0,表示長期記憶中的“遺忘”,或者將其設置爲更大的數字,表示我們從長期記憶中記住的“多少”。
新型長時記憶的輸入門及其解決方案
輸入門是將包含在輸入和隱藏狀態中的信息組合起來,然後與候選和部分候選c’'u t一起操作的地方:
在這些操作中,決定了多少新信息將被引入到內存中,如何改變——這就是爲什麼我們使用tanh函數(從-1到1)。我們將短期記憶和長期記憶中的部分候選組合起來,並將其設置爲候選。
單元的輸出門和隱藏狀態(輸出)
之後,我們可以收集o_t作爲LSTM單元的輸出門,然後將其乘以候選單元(長期存儲器)的tanh,後者已經用正確的操作進行了更新。網絡輸出爲h_t。
LSTM單元方程
在PyTorch上實現
import math
import torch
import torch.nn as nn
我們現在將通過繼承nn.Module,然後還將引用其參數和權重初始化,如下所示(請注意,其形狀由網絡的輸入大小和輸出大小決定):
class NaiveCustomLSTM(nn.Module):
def __init__(self, input_sz: int, hidden_sz: int):
super().__init__()
self.input_size = input_sz
self.hidden_size = hidden_sz
#i_t
self.U_i = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_i = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_i = nn.Parameter(torch.Tensor(hidden_sz))
#f_t
self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_f = nn.Parameter(torch.Tensor(hidden_sz))
#c_t
self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_c = nn.Parameter(torch.Tensor(hidden_sz))
#o_t
self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
self.b_o = nn.Parameter(torch.Tensor(hidden_sz))
self.init_weights()
要了解每個操作的形狀,請看:
矩陣的輸入形狀是(批量大小、序列長度、特徵長度),因此將序列的每個元素相乘的權重矩陣必須具有該形狀(特徵長度、輸出長度)。
序列上每個元素的隱藏狀態(也稱爲輸出)都具有形狀(批大小、輸出大小),這將在序列處理結束時產生輸出形狀(批大小、序列長度、輸出大小)。-因此,將其相乘的權重矩陣必須具有與單元格的參數hidden_sz相對應的形狀(output_size,output_size)。
這裏是權重初始化,我們將其用作PyTorch默認值中的權重初始化nn.Module:
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
前饋操作
前饋操作接收init_states參數,該參數是上面方程的(h_t,c_t)參數的元組,如果不引入,則設置爲零。然後,我們對每個保留(h_t,c_t)的序列元素執行LSTM方程的前饋,並將其作爲序列下一個元素的狀態引入。
最後,我們返回預測和最後一個狀態元組。讓我們看看它是如何發生的:
def forward(self,x,init_states=None):
"""
assumes x.shape represents (batch_size, sequence_size, input_size)
"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (
torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device),
)
else:
h_t, c_t = init_states
for t in range(seq_sz):
x_t = x[:, t, :]
i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)
o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
#reshape hidden_seq p/ retornar
hidden_seq = torch.cat(hidden_seq, dim=0)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)
優化版本
這個LSTM在運算上是正確的,但在計算時間上沒有進行優化:我們分別執行8個矩陣乘法,這比矢量化的方式慢得多。我們現在將演示如何通過將其減少到2個矩陣乘法來完成,這將使它更快。
爲此,我們設置了兩個矩陣U和V,它們的權重包含在4個矩陣乘法上。然後,我們對已經通過線性組合+偏置操作的矩陣執行選通操作。
通過矢量化操作,LSTM單元的方程式爲:
class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
gates = x_t @ self.W + h_t @ self.U + self.bias
i_t, f_t, g_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.tanh(gates[:, HS*2:HS*3]),
torch.sigmoid(gates[:, HS*3:]), # output
)
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)
最後但並非最不重要的是,我們可以展示如何優化,以使用LSTM peephole connections。
LSTM peephole
LSTM peephole對其前饋操作進行了細微調整,從而將其更改爲優化的情況:
如果LSTM實現得很好並經過優化,我們可以添加peephole選項,並對其進行一些小的調整:
class CustomLSTM(nn.Module):
def __init__(self, input_sz, hidden_sz, peephole=False):
super().__init__()
self.input_sz = input_sz
self.hidden_size = hidden_sz
self.peephole = peephole
self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
self.init_weights()
def init_weights(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, x,
init_states=None):
"""Assumes x is of shape (batch, sequence, feature)"""
bs, seq_sz, _ = x.size()
hidden_seq = []
if init_states is None:
h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
torch.zeros(bs, self.hidden_size).to(x.device))
else:
h_t, c_t = init_states
HS = self.hidden_size
for t in range(seq_sz):
x_t = x[:, t, :]
# batch the computations into a single matrix multiplication
if self.peephole:
gates = x_t @ U + c_t @ V + bias
else:
gates = x_t @ U + h_t @ V + bias
g_t = torch.tanh(gates[:, HS*2:HS*3])
i_t, f_t, o_t = (
torch.sigmoid(gates[:, :HS]), # input
torch.sigmoid(gates[:, HS:HS*2]), # forget
torch.sigmoid(gates[:, HS*3:]), # output
)
if self.peephole:
c_t = f_t * c_t + i_t * torch.sigmoid(x_t @ U + bias)[:, HS*2:HS*3]
h_t = torch.tanh(o_t * c_t)
else:
c_t = f_t * c_t + i_t * g_t
h_t = o_t * torch.tanh(c_t)
hidden_seq.append(h_t.unsqueeze(0))
hidden_seq = torch.cat(hidden_seq, dim=0)
# reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
hidden_seq = hidden_seq.transpose(0, 1).contiguous()
return hidden_seq, (h_t, c_t)
我們的LSTM就這樣結束了。如果有興趣大家可以將他與torch LSTM內置層進行比較。