搭建深度學習框架(三) 循環神經網絡, BPTT的梯度計算

爲什麼需要RNN

RNN是一種擁有記憶的網絡, 一旦網絡接收到了輸入, 就會改變它的隱藏變量. 這個隱藏變量會參與RNN的前向運算, 從而讓之前的輸入x, 能影響現在的輸出o. 具有這種性質的它通常用於處理序列信息. 序列信息比起之前的傳統模式分類有着一些不太好的性質, 語音和文本信號, 都是變長的. 而且文本信息常常是每個詞對應一個one-hot編碼或者一個word-embedding詞向量, 而語音信號的採樣率又非常高, 一段簡單的聲音有可能對應着幾千長度的序列. 變長的數據還比較容易處理, 畢竟我們在數據科學中也會遇到缺失值, 直接補0即可. 但是考慮到數據可能很長也可能很短, 我們想把它們一同處理就必須把所有數據對齊最長的那個.這樣就造成了不必要的算力浪費. 最長的數據的維度可能很高(幾k甚至幾百k), 也就是至少我們輸入層的input_size就很大, 即至少輸入層的參數會非常非常多. 如果設計一個巨型的網絡來處理序列數據顯然是浪費的.
我們需要更好的架構來處理序列數據, 這時RNN就很有用了. RNN可以自由處理變長序列, 但是每次運算的輸入只有一個詞向量那麼大. 這樣就大大節約了參數. 同時也減少了計算.

結構與前向傳播

在這裏插入圖片描述
RNN的計算和前饋網絡相似, 每次RNN前向傳播會同時接收兩個向量, 一個是我們當前時刻的輸入x, 另一個是保存在存儲單元中的向量h. 我們會同時用這兩個向量, 經過兩個線性層, 得到RNN的隱層輸出. 這個隱層輸出會成爲新的存儲單元中的向量h, 參與下一次運算. 而當前時刻我們還會把這個h經過輸出線性層, 再經過一個激活函數(softmax)得到當前時刻的輸出.這個過程如果進行計算圖展開就可以寫成
在這裏插入圖片描述
我們的參數一共有三個線性層, 三個權重矩陣和三個偏置. 我們在實踐時一般會把W和U對應的偏置合二爲一, 也就是這樣的RNN架構需要5種參數.
從0時刻開始, 我們的h一開始會被初始化爲0. 然後, 我們用兩個線性層和一個激活函數計算新的h.
h1=σ(x1U+h0W+b) h_1 = \sigma(x_1U+h_0W+b)
當前時刻的輸出就由h1繼續運算得到
o1=h1V+c o_1 = h_1V+c
然後, 我們會接受新的x輸入, 它和新的h一共繼續這樣運算下去
ht=σ(xtU+ht1W+b) h_t = \sigma(x_tU+h_{t-1}W+b)
ot=htV+c o_t = h_tV+c
這就是最簡單的RNN架構, 如果讓它接收完一整個序列信息, 他就可以輸出一個和整個序列都有相關性的輸出, 然後根據我們想要什麼, 就可以設置合適的損失函數, 並用梯度方法訓練它.

反向傳播(BPTT)

RNN的參數梯度該如何計算呢? 如果你使用Pytorch的計算圖模型來計算梯度, 就會發現這其實並不需要任何其他的backward_fn的設計, 因爲我們只是用了一些激活函數和線性層, 我們之前的推導已經完全夠用. 唯一需要注意的點是, 我們在計算圖中進行了權值共享, 把同一個V,W,U使用了好多遍, 這時要計算導數時, 就需要把每個V,W,U的導數都計算一次, 然後把它們加起來.
在這裏插入圖片描述
這裏我們先計算出圖中每個部分導, 然後再給出參數的導數到底該怎麼計算的公式.
首先t時刻的輸出損失LtL_t和t時刻的h是有直接相關性的, 因爲ot=htV+co_t = h_tV+c, 我們這裏可以直接計算V和c的偏導, 並計算出h關於LtL_t的偏導. 注意這並不是h的全部偏導, 我們還要考慮來自t+x時間的損失Lt+xL_{t+x}的導數.設序列的總長度爲K.
LtV=htTLtot \frac{\partial L_t}{\partial V} = h_t^T\frac{\partial L_t}{\partial o_t}
Ltc=SUMROW Ltot \frac{\partial L_t}{\partial c} = SUMROW\ \frac{\partial L_t}{\partial o_t}
Ltht=LtotVT \frac{\partial L_t}{\partial h_t} = \frac{\partial L_t}{\partial o_t}V^T
LV=k=1KhkTLkyk \frac{\partial L}{\partial V} = \sum_{k=1}^K h_k^T\frac{\partial L_k}{\partial y_k}
Lc=k=1KSUMROW Lkyk \frac{\partial L}{\partial c} = \sum_{k=1}^K SUMROW\ \frac{\partial L_k}{\partial y_k}
從上圖我們知道, 任意hth_t關於損失的導數要同時考慮LtL_tLkL_k所有這些的損失. 對一個Lk,k>tL_k, k>t, 我們要計算它對hth_t的導數如下
Lkht=Lkhki=tk1hi+1hi \frac{\partial L_k}{\partial h_t} = \frac{\partial L_k}{\partial h_k}\prod_{i=t}^{k-1}\frac{\partial h_{i+1}}{\partial h_{i}}
hi+1hi=σWT \frac{\partial h_{i+1}}{\partial h_{i}} = \sigma'W^T
如果我們使用tanh激活函數, 設vt+1=xt+1U+htW+bv_{t+1} = x_{t+1}U+h_tW+b, ht+1=tanh(vt+1)h_{t+1} = tanh(v_{t+1}), 則能寫出hi+1hi\frac{\partial h_{i+1}}{\partial h_{i}}更精確的形式.
hi+1hi=(1hi+12)WT \frac{\partial h_{i+1}}{\partial h_{i}} = (1-h_{i+1}^2)\cdot W^T
這樣我們就能給出任意hth_t關於總損失L的導數完整的形式
Lht=k=tKLkht=k=tKLkhki=tk1(1hi+12)WT \frac{\partial L}{\partial h_t} = \sum_{k = t}^K\frac{\partial L_k}{\partial h_t} = \sum_{k = t}^K\frac{\partial L_k}{\partial h_k}\prod_{i=t}^{k-1}(1-h_{i+1}^2)\cdot W^T
然後任務就是根據Lht\frac{\partial L}{\partial h_{t}}計算W,U和b的導數. 雖然W,U,b是權值共享, 我們還是把不同時刻的它們寫成Wt,Ut,btW_t,U_t,b_t方便描述
LWt=ht1TσLht=ht1T((1ht2)Lht) \frac{\partial L}{\partial W_t} = h_{t-1}^T\sigma'\frac{\partial L}{\partial h_{t}} = h_{t-1}^T ((1-h_{t}^2)\cdot \frac{\partial L}{\partial h_{t}})
LUt=xtTσLht=xtT((1ht2)Lht) \frac{\partial L}{\partial U_t} = x_{t}^T\sigma'\frac{\partial L}{\partial h_{t}} = x_{t}^T ((1-h_{t}^2)\cdot \frac{\partial L}{\partial h_{t}})
Lbt=SUMROW σLht=SUMROW ((1ht2)Lht) \frac{\partial L}{\partial b_t} = SUMROW\ \sigma'\frac{\partial L}{\partial h_{t}} = SUMROW\ ((1-h_{t}^2)\cdot \frac{\partial L}{\partial h_{t}})
權值共享的參數, 最後更新時要把這些不同時刻t得到的導數加起來, 纔是最終的損失函數關於參數的導數
LW=k=1KLWk \frac{\partial L}{\partial W} = \sum_{k=1}^K \frac{\partial L}{\partial W_k}
LU=k=1KLUk \frac{\partial L}{\partial U} = \sum_{k=1}^K \frac{\partial L}{\partial U_k}
Lb=k=1KLbk \frac{\partial L}{\partial b} = \sum_{k=1}^K \frac{\partial L}{\partial b_k}

訓練技巧

很顯然, RNN無法像CNN一樣, 實現超高度的並行化. CNN中, 我們可以在圖像與圖像間並行計算, 也能在卷積核與卷積核間並行計算. 但是RNN不行, RNN任意時刻的輸入都取決於前面時刻的運算, 直到t-1時刻的運算完成, t時刻的運算才能開始. 這就使得RNN的運算很緩慢, 儘管如此, 我們還是希望儘可能並行化計算. 之前使用的mini-batch實際上仍然能在RNN中使用. 雖然一個batch中的數據有長有短, 我們只需要按照batch中最長的那個把所有數據在時間上對齊, 不足的補零. 這樣時刻t就能一次輸入n個m維的向量x, 即輸入是n行m列矩陣{x1t;x2t...,xnt}\{x_{1t};x_{2t}...,x_{nt}\}.
另外RNN的訓練有一些很不好的性質, 我們看Lht\frac{\partial L}{\partial h_{t}}的公式, 它是把很多的ht+1ht=11vi+12WT\frac{\partial h_{t+1}}{\partial h_{t}}=\frac{1}{1-v_{i+1}^2}\cdot W^T做連乘, 而因爲WW權值共享, 每次的ht+1ht\frac{\partial h_{t+1}}{\partial h_{t}}相差並不會很大. 設想, 如果所有ht+1ht1.01\frac{\partial h_{t+1}}{\partial h_{t}}\simeq 1.01, 序列長度爲1000, 對Lh1\frac{\partial L}{\partial h_{1}}計算導數將會有1.01^{100} = 20959那麼大, 這就是RNN的短板, 梯度爆炸. 在優化的目標函數中會非常常見這樣的"懸崖", 如果我們用這個梯度更新參數, 將會讓參數直接廢掉. 爲此我們會用一些簡單的處理方式來緩解這種影響, 比如每隔T個時間單位就清零前面傳來的梯度, 或是更新參數時設計clip截斷梯度.

RNN實現

import torch
import math
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

class Tanh:
    def __init__(self):
        self.out = None
        
    def forward(self, x):
        self.out = torch.tanh(x)
        return self.out
    
    def backward(self, dz):
        return dz*(1-self.out**2)
    
    def __call__(self, X):
        return self.forward(X)
    
    
class RNN:
    def __init__(self, input_sz, hidden_sz, output_sz,
                LEARNING_RATE=0.01):
        '''
        單隱層RNN, 接收輸入向量x, 長度爲input_sz
        隱層輸出向量h, 長度hidden_sz, 輸出層輸出向量o, 長度output_sz
        參數有V,c,W,U,b
        '''
        self.hidden_sz = hidden_sz
        self.lr = LEARNING_RATE
        
        self.W = torch.randn(hidden_sz,hidden_sz)*math.sqrt(1/hidden_sz)
        # 從ht-1到ht的連接權
        self.U = torch.randn(input_sz,hidden_sz)*math.sqrt(2/(input_sz+hidden_sz)) 
        # 從X到ht的連接權
        self.b  = torch.randn(hidden_sz)*math.sqrt(2/hidden_sz)       
        # 在隱層激活前增加的偏置
        self.V = torch.randn(hidden_sz,output_sz)*math.sqrt(2/(output_sz+hidden_sz)) 
        # 從h到output的連接權
        self.c = torch.randn(output_sz)*math.sqrt(2/output_sz)     
        # 在輸出前增加的偏置
        
        self.dW,self.dU,self.db,self.dV,self.dc = torch.zeros_like(self.W),\
        torch.zeros_like(self.U),torch.zeros_like(self.b),torch.zeros_like(self.V),\
        torch.zeros_like(self.c)
        
        # 輸入
        self.h = None
        # hidden 的值
        
        self.input_x_list = []
        self.h_list = []
        self.dtanh_list = []
        
        self.tanh = Tanh()
        
    def forward(self, x):
        '''
        輸入x, size(n,m), 表示n條m維的輸入
        輸出o, 結合h和x運算得到的輸出值
        '''
        
        # forward計算
        n,m = x.shape
        if type(self.h)==type(None):
            self.h = torch.zeros(n,self.hidden_sz)
        y = x.mm(self.U)+self.h.mm(self.W)+self.b
        self.input_x_list.append(x)
        # 記錄一下舊的h的值
        self.h_list.append(self.h.clone())
        self.h = self.tanh(y)
        o = self.h.mm(self.V)+self.c
        
        # 計算部分梯度
        dtan = self.tanh.backward(1)
        self.dtanh_list.append(dtan)
        
        return o
    
    def backward(self, dout):
        # 計算dL/dV
        self.dV += self.h.T.mm(dout)
        # 計算dL/dc
        self.dc += torch.sum(dout, axis = 0)
        # 計算dL/dh_i
        dh = dout.mm(self.V.T)
        for i in range(len(self.dtanh_list)-1,-1,-1):
            dv = self.dtanh_list[i]*dh
            self.dW += self.h_list[i].T.mm(dv)
            self.dU += self.input_x_list[i].T.mm(dv)
            self.db += torch.sum(dv, axis = 0)
            dh = dv.mm(self.W.T)
            
            
    
    def clear(self):
        # 清空所有記憶體
        self.dW,self.dU,self.db,self.dV,self.dc = torch.zeros_like(self.W),\
        torch.zeros_like(self.U),torch.zeros_like(self.b),\
        torch.zeros_like(self.V),torch.zeros_like(self.c)
        
        self.input_x_list.clear()
        self.h_list.clear()
        self.dtanh_list.clear()
        
        self.h = None
    
    def clip(self):
        # 裁剪dW等參數, 如果絕對值過大就裁剪掉
        self.dW = torch.tensor(np.clip(self.dW.numpy(),-5,5))
        self.dU = torch.tensor(np.clip(self.dU.numpy(),-5,5))
        self.db = torch.tensor(np.clip(self.db.numpy(),-5,5))
        self.dV = torch.tensor(np.clip(self.dV.numpy(),-5,5))
        self.dc = torch.tensor(np.clip(self.dc.numpy(),-5,5))
        
    
    
    def update(self):
        self.clip()
        
        self.W -= self.lr*self.dW
        self.U -= self.lr*self.dU
        self.b -= self.lr*self.db
        self.V -= self.lr*self.dV
        self.c -= self.lr*self.dc
        
        self.clear()
        
    def __call__(self, X):
        return self.forward(X)

梯度驗證

和我們實現卷積神經網絡時一樣, 使用Pytorch和我們自己寫的RNN做同樣的事情, 並比較兩者反向傳播計算得到的梯度.

L = 5
n = 2

W = torch.randn(2,2)
U = torch.randn(2,2)
b = torch.randn(2)
V = torch.randn(2,1)
c = torch.randn(1)
h = torch.zeros(2,2)

W.requires_grad = True
U.requires_grad = True
b.requires_grad = True
V.requires_grad = True
c.requires_grad = True
h.requires_grad = True

input_X = torch.rand(L,2,2)
target = torch.rand(L,2,1)
loss = 0.


for i in range(L):
    v = input_X[i].mm(U)+h.mm(W)+b
    h = torch.tanh(v)
    o = h.mm(V)+c
    loss += F.mse_loss(o,target[i])/L
loss.backward()
print(W.grad)
print(U.grad)
print(b.grad)
print(V.grad)
print(c.grad)


my_rnn = RNN(2,2,1,0)
my_rnn.W = W.detach()
my_rnn.U = U.detach()
my_rnn.b = b.detach()
my_rnn.V = V.detach()
my_rnn.c = c.detach()

for i in range(L):
    o = my_rnn(input_X[i])
    my_rnn.backward(2*(o-target[i])/(L*n))

print(my_rnn.dW)
print(my_rnn.dU)
print(my_rnn.db)
print(my_rnn.dV)
print(my_rnn.dc)

實踐:時序相關序列預測

使用RNN處理時序數據有着非常好的優勢, 這裏我們用RNN預測正弦波信號. 我們在每個2π\pi週期中採樣8個點, 用多個這樣的週期數據讓RNN學習, 這樣RNN就會知道不論何時遇到序列信號, 都應該輸出正弦波.

model = RNN(1,20,1,LEARNING_RATE = 0.003)

X = torch.linspace(0,40*math.pi,160)
y = torch.sin(X)
X = X.reshape(20,-1,1)
X = X.permute(1,0,2)
y = y.reshape(20,-1,1)
y = y.permute(1,0,2)

L,n,m = X.shape

for epoch in range(10000):
    loss = 0.
    for i in range(L):
        out = model(X[i])
        dout = (out-y[i])  # mse loss
        loss += (dout**2).sum()/n
        model.backward(dout/L)
    loss /= L
    model.update()
    if (epoch+1)%500==0:
        print("epoch %d, loss %.4f"%(epoch+1,loss))

model.clear()

X_test = torch.linspace(40*math.pi,50*math.pi,40)
X_test = X_test.reshape(-1,1,1)

out_list = []
for i in range(len(X_test)):
    out = model(X_test[i])
    out_list.append(out.flatten().item())
    if (i+1)%8==0:
        model.clear()
    
xx = np.linspace(40*math.pi,50*math.pi,40)
plt.plot(xx,out_list)
plt.plot(xx,np.sin(xx))

我們用RNN預測後幾個週期的信號, 可以看見
在這裏插入圖片描述
RNN在3個週期以內都可以很好地fit真實正弦波, 在後面的週期逐漸發生偏移. 因爲與我們輸入x相乘的矩陣U並沒有完全被學習成0矩陣, 這應該是偏移的來源. 但是我們已經能看出RNN的時序預測能力.

雙向RNN

在這裏插入圖片描述
上面的RNN保證了, t時刻的輸出和1~t時刻的所有輸入都有關. 但是在做一些更復雜的問題時, 我們會希望RNN在任意時刻的輸出和整個序列都相關. 比如我們會希望預測詞性, 那麼一個詞的詞性不但要看上文, 還要看下文. 這時我們就會用這樣的雙向RNN架構. 它實際上是把兩個上面的單向RNN拼起來, 但是輸出會有一些變換, 我們的輸出計算同時考慮兩個RNN的當前隱層.
o=h1tV1+h2tV2+c o = h_{1t}V_1+h_{2t}V_2+c
反向傳播和上面大同小異, 幾乎沒有區別, 但是這樣的架構就可以幫我們看一整個序列.

LSTM

在這裏插入圖片描述
上圖是一個LSTM長短期記憶的結構體,它的引入是爲了解決simple RNN的梯度問題, rnn的最大缺點是梯度不平滑,因爲存儲單元在每一時刻,都會被另一個新的值完全覆蓋(賦值操作),雖然新的h和過去的h也有相關性,但隨着時間的推移,h將產生非常大的變化造成進行BPTT傳播時,很容易出現梯度消失和梯度爆炸。
LSTM的h單元(圖中的c單元)也會變化,但是每次變化要經過forget gate,這是一個相對更爲平滑的過程,而且實踐中forget的值會被設計的“不經常忘記”,也就使得記憶很容易長期保存下來,而新的記憶是以相加的形式得到的,這種形式的rnn更容易訓練。
在這裏插入圖片描述
簡單介紹下lstm的計算過程。每次前向傳播,我們把上一次的輸出h和本次輸入x拼起來,得到新的輸入向量x。然後x通過線性組合,成爲每一個gate的開關信號. 第一個gate是input gate,第二個gate是forget gate,第三個gate是output gate. 輸入向量被線性變化並加在幾個門上,首先得到輸入z=sigmoid(Winx)Wxz = sigmoid(W_{in}x)* Wx 然後遺忘門和輸入1共同決定現在的c存儲c=z+sigmoid(Wfgx)cc = z+sigmoid(W_{fg}x) * c 再然後,輸出門限制真正的輸出。o=sigmoid(Woutx)tanh(c)o = sigmoid(W_{out}x) * tanh(c)再到輸出層即可,輸出層的一個線性變換加softmaxy=softmax(Uo)y = softmax(Uo)
看起來運算好像很複雜, 但是其實是可以做的, 而且不算困難. 使用LSTM就能讓RNN更有效的在大部分任務上訓練, 缺點只是多了3倍的參數.

常用架構

因爲RNN允許多輸入多輸出, 它有着極爲多樣化的架構. 在這裏插入圖片描述
我們可以只用一個輸入, 後面的輸入都爲0, 也就是RNN輸出只取決於第一個輸入和隱層(one to many), 這樣的架構常用在語言模型. 我們可以一直接受輸入, 只在最後輸出一次,(many to one), 這種架構常用在序列分類模型. 還有更多樣的many to many, 那就能做更多的事情了.

小結

這次我們實現了RNN, 到此, 最基礎的深度模型框架就算開發完畢了. 我們已經學習了DNN全連接架構, 它能處理很多傳統模式識別問題, 而且在數據量較大時表現超過大多數傳統機器學習算法. 我們學習了CNN卷積神經網絡, CNN在圖像數據之類的局部特徵數據表現比DNN更好. 我們還學習了RNN, RNN廣泛應用於語言處理, 並且能幫助我們處理很長的序列數據.
後面我計劃至少要再講一下注意力機制的框架開發, 還有一些比較重要的深度學習架構. 比如生成模型GAN, 基於RNN的Seq2Seq, 目標檢測RCNN等等. 來日方長, 喜歡的朋友點個贊吧.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章