小白都能理解的矩陣與向量求導鏈式法則

0.前言

深度學習中最常見的是各種向量還有矩陣運算,經常會涉及到求導操作。因此準確理解向量矩陣的求導操作就顯得非常重要,對我們推導計算過程以及代碼書寫覈對有非常大的幫助。
神經網絡中最常見的操作爲向量,矩陣乘法,在求導的時候經常需要用到鏈式法則,鏈式法則在計算過程中會稍微麻煩,下面我們來詳細推導一下,推導過程全程簡單明瞭,稍微有點數學基礎的同學都能看明白。

1.標量對標量的鏈式求導

假設x, y, z都爲標量(或者說一維向量),鏈式關係爲x -> y -> z。根據高數中的鏈式法則
zx=zyyx\frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial x}

上面的計算過程很簡單,不多解釋。

2.向量對向量鏈式求導

假設x,y,z都爲向量,鏈式關係爲x -> y -> z。如果我們要求zx\frac{\partial z}{\partial x},可以直接用鏈接法則求導
zx=zyyx\frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial x}

假設x, y, z的維度分別爲m, n, p,zx\frac{\partial z}{\partial x}的維度爲p * m,而zy\frac{\partial z}{\partial y}的維度爲p * n,yx\frac{\partial y}{\partial x}的維度爲n * m,p * n與n * m的維度剛好爲p * m,與左邊相同。

3.標量對多向量的鏈式求導

在深度學習中,一般我們的損失函數爲一個標量函數,比如MSE或者Cross Entropy,因此最後求導的目標函數爲標量。
假設我們最終優化的目標爲z是個標量,x,y分爲是m,n維向量,依賴關係爲x->y->z。現在需要求的是zx\frac{\partial z}{\partial x},維度爲m * 1。
易知有zy\frac{\partial z}{\partial y}爲n * 1, yx\frac{\partial y}{\partial x}爲n * m,則(yx)Tzy(\frac{\partial y}{\partial x}) ^ T \cdot \frac{\partial z}{\partial y}的維度爲m * 1,與左邊能對上。
因此有
zx=(yx)Tzy\frac{\partial z}{\partial x} = (\frac{\partial y}{\partial x}) ^ T \cdot \frac{\partial z}{\partial y}
擴展到多個向量
y1 -> y2 -> y3 -> …-> yn -> z
zy1=(ynyn1yn1yn2y2y1)Tzyn \frac{\partial z}{\partial y_1} = (\frac{\partial y_n}{\partial y_{n-1}} \cdot \frac{\partial y_{n-1}}{\partial y_{n-2}} \cdots \frac{\partial y_2}{\partial y_1}) ^ T \cdot \frac{\partial z}{\partial y_n}

以常見的最小二乘求導爲例:
C=(Xθy)T(Xθy)C = (X\theta - y) ^ T (X\theta - y)
損失函數C是個標量,假設X爲m*n的矩陣,θ\thetan1n*1的向量,我們要求C對θ\theta的導數,令z=Xθyz = X\theta - yC=zTzC = z^Tz,由上面的連式關係
Cθ=(zθ)TCz=2XT(Xθy)\frac{\partial C}{\partial \theta} = (\frac{\partial z}{\partial \theta})^T \cdot \frac{\partial C}{\partial z} = 2X^T(X \theta - y)

覈對一下維度
Cθ\frac{\partial C}{\partial \theta}是n * 1, XTX^T是n * m,XθyX \theta - y是m * 1, XT(Xθy)X^T (X \theta - y)是n * 1,能與左邊對上。

其中
(Xθy)θ=X\frac{\partial (X \theta - y)}{\partial \theta} = X
(zTz)z=2z\frac{\partial (z^Tz)}{\partial z} = 2z

XθyX \theta - y爲m * 1, θ\theta爲n * 1, X的維度剛好爲m * n。
zTzz^Tz是個標量,對zz求導結果爲2z2z,與zz的維度一致。

所以最小二乘最優解的矩陣表達式爲
2XT(Xθy)=02X^T(X \theta - y) = 0
θ=(XTX)1XTy\theta = (X^TX)^{-1}X^Ty

4.標量對多矩陣鏈式求導

神經網絡中,最常見的計算方式是Y=WX+bY = WX + b,其中WW爲權值矩陣。
看個更爲常規的描述:
假設z=f(Y)z = f(Y)Y=AX+BY = AX + B,其中A爲m * k矩陣,X爲k * 1向量,B爲m * 1向量,那麼Y也爲m * 1向量,z爲一個標量

如果要求zX\frac{\partial z}{\partial X},結果爲k * 1的維度。zY\frac{\partial z}{\partial Y}的維度爲m * 1, A的維度爲m * k
zX=ATzY\frac{\partial z}{\partial X} = A ^ T \cdot \frac{\partial z}{\partial Y}

左邊的維度爲k * 1,右邊的維度爲k * m 與m * 1相乘,也是k * 1,剛好能對上。

當X爲矩陣時,按同樣的方式進行推導可以得到一樣的結論。

如果要求zA\frac{\partial z}{\partial A},結果爲m * k矩陣。zY\frac{\partial z}{\partial Y}的維度爲m * 1, X的維度爲k * 1,
zA=zYXT\frac{\partial z}{\partial A} = \frac{\partial z}{\partial Y} \cdot X^T
當X爲矩陣時,按同樣的方式進行推導可以得到一樣的結論。

5.總結

一句話總結就是:標量對向量或者矩陣進行鏈式求導的時候,按照維度將結果對其就特別容易推導。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章