使用PyTorch手寫代碼從頭構建LSTM,更深入的理解其工作原理

這是一個造輪子的過程,但是從頭構建LSTM能夠使我們對體系結構進行更加了解,並將我們的研究帶入下一個層次。

LSTM單元是遞歸神經網絡深度學習研究領域中最有趣的結構之一:它不僅使模型能夠從長序列中學習,而且還爲長、短期記憶創建了一個數值抽象,可以在需要時相互替換。

在這篇文章中,我們不僅將介紹LSTM單元的體系結構,還將通過PyTorch手工實現它。

最後但最不重要的是,我們將展示如何對我們的實現做一些小的調整,以實現一些新的想法,這些想法確實出現在LSTM研究領域,如peephole。

LSTM體系結構

LSTM被稱爲門結構:一些數學運算的組合,這些運算使信息流動或從計算圖的那裏保留下來。因此,它能夠“決定”其長期和短期記憶,並輸出對序列數據的可靠預測:

LSTM單元中的預測序列。注意,它不僅會傳遞預測值,而且還會傳遞一個c,c是長期記憶的代表

遺忘門

遺忘門(forget gate)是輸入信息與候選者一起操作的門,作爲長期記憶。請注意,在輸入、隱藏狀態和偏差的第一個線性組合上,應用一個sigmoid函數:

sigmoid將遺忘門的輸出“縮放”到0-1之間,然後,通過將其與候選者相乘,我們可以將其設置爲0,表示長期記憶中的“遺忘”,或者將其設置爲更大的數字,表示我們從長期記憶中記住的“多少”。

新型長時記憶的輸入門及其解決方案

輸入門是將包含在輸入和隱藏狀態中的信息組合起來,然後與候選和部分候選c’'u t一起操作的地方:

在這些操作中,決定了多少新信息將被引入到內存中,如何改變——這就是爲什麼我們使用tanh函數(從-1到1)。我們將短期記憶和長期記憶中的部分候選組合起來,並將其設置爲候選。

單元的輸出門和隱藏狀態(輸出)

之後,我們可以收集o_t作爲LSTM單元的輸出門,然後將其乘以候選單元(長期存儲器)的tanh,後者已經用正確的操作進行了更新。網絡輸出爲h_t。

LSTM單元方程

在PyTorch上實現

import math
import torch
import torch.nn as nn

我們現在將通過繼承nn.Module,然後還將引用其參數和權重初始化,如下所示(請注意,其形狀由網絡的輸入大小和輸出大小決定):

class NaiveCustomLSTM(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        
        #i_t
        self.U_i = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_i = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_i = nn.Parameter(torch.Tensor(hidden_sz))
        
        #f_t
        self.U_f = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_f = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_f = nn.Parameter(torch.Tensor(hidden_sz))
        
        #c_t
        self.U_c = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_c = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_c = nn.Parameter(torch.Tensor(hidden_sz))
        
        #o_t
        self.U_o = nn.Parameter(torch.Tensor(input_sz, hidden_sz))
        self.V_o = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz))
        self.b_o = nn.Parameter(torch.Tensor(hidden_sz))
        
        self.init_weights()

要了解每個操作的形狀,請看:

矩陣的輸入形狀是(批量大小、序列長度、特徵長度),因此將序列的每個元素相乘的權重矩陣必須具有該形狀(特徵長度、輸出長度)。

序列上每個元素的隱藏狀態(也稱爲輸出)都具有形狀(批大小、輸出大小),這將在序列處理結束時產生輸出形狀(批大小、序列長度、輸出大小)。-因此,將其相乘的權重矩陣必須具有與單元格的參數hidden_sz相對應的形狀(output_size,output_size)。

這裏是權重初始化,我們將其用作PyTorch默認值中的權重初始化nn.Module:

def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

前饋操作

前饋操作接收init_states參數,該參數是上面方程的(h_t,c_t)參數的元組,如果不引入,則設置爲零。然後,我們對每個保留(h_t,c_t)的序列元素執行LSTM方程的前饋,並將其作爲序列下一個元素的狀態引入。

最後,我們返回預測和最後一個狀態元組。讓我們看看它是如何發生的:

def forward(self,x,init_states=None):
        
        """
        assumes x.shape represents (batch_size, sequence_size, input_size)
        """
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        
        if init_states is None:
            h_t, c_t = (
                torch.zeros(bs, self.hidden_size).to(x.device),
                torch.zeros(bs, self.hidden_size).to(x.device),
            )
        else:
            h_t, c_t = init_states
            
        for t in range(seq_sz):
            x_t = x[:, t, :]
            
            i_t = torch.sigmoid(x_t @ self.U_i + h_t @ self.V_i + self.b_i)
            f_t = torch.sigmoid(x_t @ self.U_f + h_t @ self.V_f + self.b_f)
            g_t = torch.tanh(x_t @ self.U_c + h_t @ self.V_c + self.b_c)
            o_t = torch.sigmoid(x_t @ self.U_o + h_t @ self.V_o + self.b_o)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            
            hidden_seq.append(h_t.unsqueeze(0))
        
        #reshape hidden_seq p/ retornar
        hidden_seq = torch.cat(hidden_seq, dim=0)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)

優化版本

這個LSTM在運算上是正確的,但在計算時間上沒有進行優化:我們分別執行8個矩陣乘法,這比矢量化的方式慢得多。我們現在將演示如何通過將其減少到2個矩陣乘法來完成,這將使它更快。

爲此,我們設置了兩個矩陣U和V,它們的權重包含在4個矩陣乘法上。然後,我們對已經通過線性組合+偏置操作的矩陣執行選通操作。

通過矢量化操作,LSTM單元的方程式爲:

class CustomLSTM(nn.Module):
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()
                
    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
         
    def forward(self, x, 
                init_states=None):
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device), 
                        torch.zeros(bs, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states
         
        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the computations into a single matrix multiplication
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_t, f_t, g_t, o_t = (
                torch.sigmoid(gates[:, :HS]), # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.tanh(gates[:, HS*2:HS*3]),
                torch.sigmoid(gates[:, HS*3:]), # output
            )
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t.unsqueeze(0))
        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)

最後但並非最不重要的是,我們可以展示如何優化,以使用LSTM peephole connections。

LSTM peephole

LSTM peephole對其前饋操作進行了細微調整,從而將其更改爲優化的情況:


如果LSTM實現得很好並經過優化,我們可以添加peephole選項,並對其進行一些小的調整:

class CustomLSTM(nn.Module):
    def __init__(self, input_sz, hidden_sz, peephole=False):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.peephole = peephole
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()
                
    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)
         
    def forward(self, x, 
                init_states=None):
        """Assumes x is of shape (batch, sequence, feature)"""
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device), 
                        torch.zeros(bs, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states
         
        HS = self.hidden_size
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # batch the computations into a single matrix multiplication
            
            if self.peephole:
                gates = x_t @ U + c_t @ V + bias
            else:
                gates = x_t @ U + h_t @ V + bias
                g_t = torch.tanh(gates[:, HS*2:HS*3])
            
            i_t, f_t, o_t = (
                torch.sigmoid(gates[:, :HS]), # input
                torch.sigmoid(gates[:, HS:HS*2]), # forget
                torch.sigmoid(gates[:, HS*3:]), # output
            )
            
            if self.peephole:
                c_t = f_t * c_t + i_t * torch.sigmoid(x_t @ U + bias)[:, HS*2:HS*3]
                h_t = torch.tanh(o_t * c_t)
            else:
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)
                
            hidden_seq.append(h_t.unsqueeze(0))
            
        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        
        return hidden_seq, (h_t, c_t)

我們的LSTM就這樣結束了。如果有興趣大家可以將他與torch LSTM內置層進行比較。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章