Matrix Differentiation

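Matrices appear throughout machine learning, most visibly in backpropagation. Backpropagating through matrix operations is actually simple, but most material online makes it look complicated, so I will try to present it here in as simple a form as I can.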

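  • Backpropagation through matrix multiplication
  • Backpropagation through the element-wise matrix product
  • Backpropagation through a matrix times a vector
  • Backpropagation through a vector times a matrix
  • Backpropagation through a matrix times a scalar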

This article writes x for a scalar, \textbf{x} for a vector, X for a matrix, \mathbf{x} for a column vector, and \mathbf{x}^{T} for a row vector.

Backpropagation through matrix multiplication

Let X be a 3×2 matrix, Y a 2×4 matrix, and Z=XY their 3×4 product. Written out entry by entry:

X=\begin{bmatrix} x_{11} &x_{12} \\ x_{21} &x_{22} \\ x_{31} &x_{32} \end{bmatrix} 

Y=\begin{bmatrix} y_{11} &y_{12} &y_{13} &y_{14} \\ y_{21} &y_{22} &y_{23} &y_{24} \end{bmatrix}

Z=\begin{bmatrix} z_{11} &z_{12} &z_{13} &z_{14} \\ z_{21} &z_{22} &z_{23} &z_{24} \\ z_{31} &z_{32} &z_{33} &z_{34} \end{bmatrix}

where z_{ij}=\sum_{k=1}^{2} x_{ik}y_{kj}, i=1,2,3; j=1,2,3,4

The partial derivative of the scalar loss l with respect to X is then:

\frac{\partial l}{\partial X} =\begin{bmatrix} \frac{\partial l}{\partial x_{11}} & \frac{\partial l}{\partial x_{12}} \\ \frac{\partial l}{\partial x_{21}} & \frac{\partial l}{\partial x_{22}} \\ \frac{\partial l}{\partial x_{31}} & \frac{\partial l}{\partial x_{32}} \end{bmatrix} = \begin{bmatrix} \sum_{i=1}^{3}\sum_{j=1}^{4}\frac{\partial l}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial x_{11}} & \sum_{i=1}^{3}\sum_{j=1}^{4}\frac{\partial l}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial x_{12}} \\ \sum_{i=1}^{3}\sum_{j=1}^{4}\frac{\partial l}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial x_{21}} & \sum_{i=1}^{3}\sum_{j=1}^{4}\frac{\partial l}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial x_{22}} \\ \sum_{i=1}^{3}\sum_{j=1}^{4}\frac{\partial l}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial x_{31}} & \sum_{i=1}^{3}\sum_{j=1}^{4}\frac{\partial l}{\partial z_{ij}}\frac{\partial z_{ij}}{\partial x_{32}} \end{bmatrix} \qquad (1)

 

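Because \frac{\partial z_{ij}}{\partial x_{ab}}=y_{bj} when i=a and 0 otherwise, each entry of (1) reduces to \frac{\partial l}{\partial x_{ab}}=\sum_{j=1}^{4}\frac{\partial l}{\partial z_{aj}}y_{bj}, which is entry (a,b) of \frac{\partial l}{\partial Z}Y^{T}. Therefore \frac{\partial l}{\partial X}=\frac{\partial l}{\partial Z}\frac{\partial Z}{\partial X}=\frac{\partial l}{\partial Z}Y^{T}, where the middle expression is shorthand for the chain rule rather than a literal matrix product.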

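By the same argument, \frac{\partial l}{\partial y_{ab}}=\sum_{i=1}^{3}x_{ia}\frac{\partial l}{\partial z_{ib}}, so \frac{\partial l}{\partial Y}=\frac{\partial l}{\partial Z}\frac{\partial Z}{\partial Y}=X^{T}\frac{\partial l}{\partial Z}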
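The two formulas above can be checked numerically. The sketch below is only an illustration under stated assumptions: the loss is taken to be l = \sum_{ij} g_{ij} z_{ij} for an arbitrary fixed matrix G (so that \frac{\partial l}{\partial Z}=G exactly), and the helper names `matmul`, `transpose`, `numeric_grad` and `max_abs_diff` are hypothetical, written with plain Python lists to avoid any dependency.

```python
# Check dl/dX = G Y^T and dl/dY = X^T G for Z = X Y.
# Assumption: l = sum_ij G[i][j] * Z[i][j], so dl/dZ = G.

def matmul(A, B):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def loss(X, Y, G):
    Z = matmul(X, Y)
    return sum(g * z for g_row, z_row in zip(G, Z) for g, z in zip(g_row, z_row))

def numeric_grad(f, M, eps=1e-6):
    """Central-difference gradient of the scalar f() w.r.t. every entry of M."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for j in range(len(M[0])):
            old = M[i][j]
            M[i][j] = old + eps
            up = f()
            M[i][j] = old - eps
            down = f()
            M[i][j] = old  # restore the entry before moving on
            grad[i][j] = (up - down) / (2 * eps)
    return grad

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]                                    # 3 x 2
Y = [[0.5, -1.0, 2.0, 0.0], [1.5, 0.25, -0.5, 3.0]]                         # 2 x 4
G = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 0.0, -2.0], [3.0, -1.0, 1.0, 0.5]]   # 3 x 4

dX = matmul(G, transpose(Y))   # analytic gradient, 3 x 2
dY = matmul(transpose(X), G)   # analytic gradient, 2 x 4

dX_num = numeric_grad(lambda: loss(X, Y, G), X)
dY_num = numeric_grad(lambda: loss(X, Y, G), Y)

print(max_abs_diff(dX, dX_num), max_abs_diff(dY, dY_num))
```

Both printed differences should be at the level of floating-point noise. Note that each gradient has the same shape as the quantity it belongs to, which is a quick way to remember on which side the transpose goes.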

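Backpropagation through the element-wise matrix product

Similarly, let Z=X\odot Y be the element-wise (Hadamard) product, sometimes written Xdot(Y), with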

X=\begin{bmatrix} x_{11} &x_{12} \\ x_{21} &x_{22} \\ x_{31} &x_{32} \end{bmatrix} 

Y=\begin{bmatrix} y_{11} &y_{12} \\ y_{21} &y_{22} \\ y_{31} &y_{32} \end{bmatrix} 

Z=\begin{bmatrix} z_{11} &z_{12} \\ z_{21} &z_{22} \\ z_{31} &z_{32} \end{bmatrix}

where z_{ij}=x_{ij}y_{ij}, i=1,2,3; j=1,2. Each z_{ij} depends only on x_{ij}, and \frac{\partial z_{ij}}{\partial x_{ij}}=y_{ij}, so

\frac{\partial l}{\partial X} =\begin{bmatrix} \frac{\partial l}{\partial x_{11}} & \frac{\partial l}{\partial x_{12}} \\ \frac{\partial l}{\partial x_{21}} & \frac{\partial l}{\partial x_{22}} \\ \frac{\partial l}{\partial x_{31}} & \frac{\partial l}{\partial x_{32}} \end{bmatrix} = \begin{bmatrix} \frac{\partial l}{\partial z_{11}}\frac{\partial z_{11}}{\partial x_{11}} & \frac{\partial l}{\partial z_{12}}\frac{\partial z_{12}}{\partial x_{12}} \\ \frac{\partial l}{\partial z_{21}}\frac{\partial z_{21}}{\partial x_{21}} & \frac{\partial l}{\partial z_{22}}\frac{\partial z_{22}}{\partial x_{22}} \\ \frac{\partial l}{\partial z_{31}}\frac{\partial z_{31}}{\partial x_{31}} & \frac{\partial l}{\partial z_{32}}\frac{\partial z_{32}}{\partial x_{32}} \end{bmatrix}=\frac{\partial l}{\partial Z}\odot Y

Likewise, with \odot denoting the element-wise product,

\frac{\partial l}{\partial Y}=\frac{\partial l}{\partial Z}\odot X
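A small numerical check of the element-wise case, again as a sketch only. It assumes the loss is the sum of all entries of G multiplied element-wise with Z, for an arbitrary fixed G standing in for \frac{\partial l}{\partial Z}; `hadamard` and `central_diff` are hypothetical helper names.

```python
# Check dl/dX = G * Y and dl/dY = G * X (element-wise) for Z = X * Y (element-wise).
# Assumption: l = sum_ij G[i][j] * Z[i][j], so dl/dZ = G.

def hadamard(A, B):
    """Element-wise product of two equally shaped matrices (lists of rows)."""
    return [[a * b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def loss(X, Y, G):
    Z = hadamard(X, Y)
    return sum(sum(row) for row in hadamard(G, Z))

def central_diff(f, row, j, eps=1e-6):
    """Central difference of the scalar f() w.r.t. the single entry row[j]."""
    old = row[j]
    row[j] = old + eps
    up = f()
    row[j] = old - eps
    down = f()
    row[j] = old  # restore the entry
    return (up - down) / (2 * eps)

X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
Y = [[0.5, -1.0], [2.0, 0.0], [1.5, 0.25]]
G = [[1.0, -1.0], [0.5, 2.0], [-2.0, 3.0]]   # stands in for dl/dZ

dX = hadamard(G, Y)   # analytic gradient w.r.t. X
dY = hadamard(G, X)   # analytic gradient w.r.t. Y

def f():
    return loss(X, Y, G)

dX_num = [[central_diff(f, X[i], j) for j in range(2)] for i in range(3)]
dY_num = [[central_diff(f, Y[i], j) for j in range(2)] for i in range(3)]

print(dX)
print(dX_num)
```

Unlike the matrix product, no sum over indices and no transpose appears here, because every output entry depends on exactly one entry of each input.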

Backpropagation through a matrix times a vector

\textbf{y}=W\textbf{x}, where W is 3×2, \textbf{x} is 2×1, and \textbf{y} is 3×1:
W=\begin{bmatrix} w_{11} &w_{12} \\ w_{21} &w_{22} \\w_{31} &w_{32} \end{bmatrix}

\textbf{x}=\begin{bmatrix} x_{11}\\ x_{21} \end{bmatrix}

\textbf{y}=\begin{bmatrix} y_{11}\\ y_{21} \\ y_{31}\end{bmatrix}

\frac{\partial l}{\partial \textbf{x}} =\begin{bmatrix} \frac{\partial l}{\partial x_{11}}\\ \frac{\partial l}{\partial x_{21}} \end{bmatrix} = \begin{bmatrix} \frac{\partial l}{\partial y_{11}}\frac{\partial y_{11}}{\partial x_{11}}+\frac{\partial l}{\partial y_{21}}\frac{\partial y_{21}}{\partial x_{11}}+\frac{\partial l}{\partial y_{31}}\frac{\partial y_{31}}{\partial x_{11}} \\ \frac{\partial l}{\partial y_{11}}\frac{\partial y_{11}}{\partial x_{21}}+\frac{\partial l}{\partial y_{21}}\frac{\partial y_{21}}{\partial x_{21}}+\frac{\partial l}{\partial y_{31}}\frac{\partial y_{31}}{\partial x_{21}}\end{bmatrix}=\begin{bmatrix} \frac{\partial l}{\partial y_{11}}w_{11}+\frac{\partial l}{\partial y_{21}}w_{21}+\frac{\partial l}{\partial y_{31}}w_{31}\\ \frac{\partial l}{\partial y_{11}}w_{12}+\frac{\partial l}{\partial y_{21}}w_{22}+\frac{\partial l}{\partial y_{31}}w_{32} \end{bmatrix}=W^{T}\frac{\partial l}{\partial \textbf{y}}

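By the same argument, since y_{i1}=\sum_{k=1}^{2}w_{ik}x_{k1} gives \frac{\partial l}{\partial w_{ij}}=\frac{\partial l}{\partial y_{i1}}x_{j1}, we get \frac{\partial l}{\partial W}=\frac{\partial l}{\partial \textbf{y}}\mathbf{x}^{T}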
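The matrix-times-vector formulas can be verified the same way. This is a sketch under assumptions: the loss is l = \sum_i g_i y_{i1} for an arbitrary fixed vector g (so \frac{\partial l}{\partial \textbf{y}}=g), column vectors are stored as flat lists, and `matvec` and `central_diff` are hypothetical helper names.

```python
# Check dl/dx = W^T g and dl/dW = g x^T for y = W x.
# Assumption: l = sum_i g[i] * y[i], so dl/dy = g.

def matvec(W, x):
    """Product of a matrix (list of rows) with a column vector (flat list)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def loss(W, x, g):
    return sum(gi * yi for gi, yi in zip(g, matvec(W, x)))

def central_diff(f, values, j, eps=1e-6):
    """Central difference of the scalar f() w.r.t. the single entry values[j]."""
    old = values[j]
    values[j] = old + eps
    up = f()
    values[j] = old - eps
    down = f()
    values[j] = old  # restore the entry
    return (up - down) / (2 * eps)

W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # 3 x 2
x = [0.5, -1.5]                            # 2 x 1
g = [1.0, -2.0, 0.5]                       # dl/dy, 3 x 1

dx = [sum(W[i][j] * g[i] for i in range(3)) for j in range(2)]   # W^T g, 2 x 1
dW = [[g[i] * x[j] for j in range(2)] for i in range(3)]         # g x^T, 3 x 2

def f():
    return loss(W, x, g)

dx_num = [central_diff(f, x, j) for j in range(2)]
dW_num = [[central_diff(f, W[i], j) for j in range(2)] for i in range(3)]

print(dx, dx_num)
```

The gradient with respect to W is an outer product: a 3×1 column times a 1×2 row, which again matches the 3×2 shape of W.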

Backpropagation through a vector times a matrix

\textbf{y}=\textbf{x}W, where W is 3×2, \textbf{x} is 1×3, and \textbf{y} is 1×2:
W=\begin{bmatrix} w_{11} &w_{12} \\ w_{21} &w_{22} \\w_{31} &w_{32} \end{bmatrix}

\textbf{x}=\begin{bmatrix} x_{11}& x_{12} & x_{13}\end{bmatrix}

\textbf{y}=\begin{bmatrix} y_{11}& y_{12}\end{bmatrix}

{\frac{\partial l}{\partial \textbf{x}}} =\begin{bmatrix} \frac{\partial l}{\partial x_{11}}& \frac{\partial l}{\partial x_{12}} & \frac{\partial l}{\partial x_{13}}\end{bmatrix} =\begin{bmatrix} \frac{\partial l}{\partial y_{11}}\frac{\partial y_{11}}{\partial x_{11}}+ \frac{\partial l}{\partial y_{12}}\frac{\partial y_{12}}{\partial x_{11}}& \frac{\partial l}{\partial y_{11}}\frac{\partial y_{11}}{\partial x_{12}}+\frac{\partial l}{\partial y_{12}}\frac{\partial y_{12}}{\partial x_{12}} & \frac{\partial l}{\partial y_{11}}\frac{\partial y_{11}}{\partial x_{13}}+ \frac{\partial l}{\partial y_{12}}\frac{\partial y_{12}}{\partial x_{13}} \end{bmatrix}=\begin{bmatrix} \frac{\partial l}{\partial y_{11}}w_{11}+ \frac{\partial l}{\partial y_{12}}w_{12}& \frac{\partial l}{\partial y_{11}}w_{21}+ \frac{\partial l}{\partial y_{12}}w_{22} & \frac{\partial l}{\partial y_{11}}w_{31}+ \frac{\partial l}{\partial y_{12}}w_{32} \end{bmatrix}={\frac{\partial l}{\partial \textbf{y}}}W^{T}

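By the same argument, since y_{1j}=\sum_{k=1}^{3}x_{1k}w_{kj} gives \frac{\partial l}{\partial w_{kj}}=x_{1k}\frac{\partial l}{\partial y_{1j}}, we get \frac{\partial l}{\partial W}=\mathbf{x}^{T}\frac{\partial l}{\partial \textbf{y}}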
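For the row-vector case, a comparable check is sketched below. The assumptions are the same kind as before: l = \sum_j g_j y_{1j} for an arbitrary fixed row vector g, row vectors stored as flat lists, and hypothetical helper names `vecmat` and `central_diff`.

```python
# Check dl/dx = g W^T and dl/dW = x^T g for y = x W.
# Assumption: l = sum_j g[j] * y[j], so dl/dy = g.

def vecmat(x, W):
    """Product of a row vector (flat list) with a matrix (list of rows)."""
    return [sum(xk * row[j] for xk, row in zip(x, W)) for j in range(len(W[0]))]

def loss(x, W, g):
    return sum(gj * yj for gj, yj in zip(g, vecmat(x, W)))

def central_diff(f, values, j, eps=1e-6):
    """Central difference of the scalar f() w.r.t. the single entry values[j]."""
    old = values[j]
    values[j] = old + eps
    up = f()
    values[j] = old - eps
    down = f()
    values[j] = old  # restore the entry
    return (up - down) / (2 * eps)

x = [1.0, -2.0, 0.5]                       # 1 x 3
W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # 3 x 2
g = [2.0, -1.0]                            # dl/dy, 1 x 2

dx = [sum(g[j] * W[k][j] for j in range(2)) for k in range(3)]   # g W^T, 1 x 3
dW = [[x[k] * g[j] for j in range(2)] for k in range(3)]         # x^T g, 3 x 2

def f():
    return loss(x, W, g)

dx_num = [central_diff(f, x, k) for k in range(3)]
dW_num = [[central_diff(f, W[k], j) for j in range(2)] for k in range(3)]

print(dx, dx_num)
```

Compared with the column-vector case, the only change is the side on which each factor appears, which follows from requiring every gradient to keep the shape of its variable.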

Backpropagation through a matrix times a scalar

Y=xW, where W is 3×2, x is a scalar, and Y is 3×2:
W=\begin{bmatrix} w_{11} &w_{12} \\ w_{21} &w_{22} \\w_{31} &w_{32} \end{bmatrix}

x=x_{11}

Y=\begin{bmatrix} y_{11} &y_{12} \\ y_{21} &y_{22} \\y_{31} &y_{32} \end{bmatrix}

\frac{\partial l}{\partial x} =\frac{\partial l}{\partial y_{11}}\frac{\partial y_{11}}{\partial x}+\frac{\partial l}{\partial y_{12}}\frac{\partial y_{12}}{\partial x}+\frac{\partial l}{\partial y_{21}}\frac{\partial y_{21}}{\partial x}+\frac{\partial l}{\partial y_{22}}\frac{\partial y_{22}}{\partial x}+\frac{\partial l}{\partial y_{31}}\frac{\partial y_{31}}{\partial x}+\frac{\partial l}{\partial y_{32}}\frac{\partial y_{32}}{\partial x}=\frac{\partial l}{\partial y_{11}}w_{11}+\frac{\partial l}{\partial y_{12}}w_{12}+\frac{\partial l}{\partial y_{21}}w_{21}+\frac{\partial l}{\partial y_{22}}w_{22}+\frac{\partial l}{\partial y_{31}}w_{31}+\frac{\partial l}{\partial y_{32}}w_{32}=\sum_{i=1}^{3}\sum_{j=1}^{2}\left(\frac{\partial l}{\partial Y}\odot W\right)_{ij}

Here \odot is the element-wise product, so the gradient with respect to x is the sum of all entries of \frac{\partial l}{\partial Y}\odot W.

\frac{\partial l}{\partial W} =\begin{bmatrix} \frac{\partial l}{\partial w_{11}} & \frac{\partial l}{\partial w_{12}} \\ \frac{\partial l}{\partial w_{21}} & \frac{\partial l}{\partial w_{22}} \\ \frac{\partial l}{\partial w_{31}} & \frac{\partial l}{\partial w_{32}} \end{bmatrix} =\begin{bmatrix} \frac{\partial l}{\partial y_{11}}\frac{\partial y_{11}}{\partial w_{11}} & \frac{\partial l}{\partial y_{12}}\frac{\partial y_{12}}{\partial w_{12}} \\ \frac{\partial l}{\partial y_{21}}\frac{\partial y_{21}}{\partial w_{21}} & \frac{\partial l}{\partial y_{22}}\frac{\partial y_{22}}{\partial w_{22}} \\ \frac{\partial l}{\partial y_{31}}\frac{\partial y_{31}}{\partial w_{31}} & \frac{\partial l}{\partial y_{32}}\frac{\partial y_{32}}{\partial w_{32}} \end{bmatrix}=x\,\frac{\partial l}{\partial Y}
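The scalar case can be checked with a few lines as well. The sketch assumes l = \sum_{ij} g_{ij} y_{ij} for an arbitrary fixed matrix G standing in for \frac{\partial l}{\partial Y}; the names and values are illustrative only.

```python
# Check dl/dx = sum of all entries of (G * W) and dl/dW = x * G for Y = x * W.
# Assumption: l = sum_ij G[i][j] * Y[i][j], so dl/dY = G.

def loss(x, W, G):
    # y_ij = x * w_ij, weighted by g_ij and summed over all entries
    return sum(g * x * w for g_row, w_row in zip(G, W) for g, w in zip(g_row, w_row))

x = 2.0
W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]     # 3 x 2
G = [[1.0, 0.0], [-1.0, 2.0], [0.5, -0.5]]   # dl/dY, 3 x 2

# Analytic gradients from the formulas above
dx = sum(g * w for g_row, w_row in zip(G, W) for g, w in zip(g_row, w_row))
dW = [[x * g for g in g_row] for g_row in G]

# Central-difference estimates
eps = 1e-6
dx_num = (loss(x + eps, W, G) - loss(x - eps, W, G)) / (2 * eps)

dW_num = []
for i in range(3):
    row = []
    for j in range(2):
        old = W[i][j]
        W[i][j] = old + eps
        up = loss(x, W, G)
        W[i][j] = old - eps
        down = loss(x, W, G)
        W[i][j] = old  # restore the entry
        row.append((up - down) / (2 * eps))
    dW_num.append(row)

print(dx, dx_num)
print(dW)
```

Because x touches every entry of Y, its gradient collects a contribution from each entry, which is why a sum over the whole matrix appears only in \frac{\partial l}{\partial x} and not in \frac{\partial l}{\partial W}.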

 
