快速矩陣乘法的研究——上

快速矩陣乘法的研究

最近的工作主要在於深度學習框架的性能優化。深度學習框架在工程的優化(內存池、SIMD、彙編、GPU、DSP等等)做到接近極限之後,突破點便集中於算法。

深度學習的性能瓶頸主要在於卷積,卷積的運算方法主要是通過 Im2Col / Winograd / FFT 轉化爲矩陣乘,完成矩陣乘法之後,再轉化爲目標結果。

深度學習框架的輸入是算法工程產出的網絡模型,而目前網絡模型都漸漸地轉變爲 mobilenet 那樣 1x1 convolution + depthwise 的形式,在精度幾乎無損的情況下,既減少了計算量,又減少了模型體積。而這類網絡模型,都以 1x1 卷積爲主要耗時點。

對 1x1 卷積而言,其本身就是一個矩陣乘法,FFT / Winograd 等卷積算法已經失去價值,因此研讀了一些矩陣乘法相關的論文,整理如下。

傳統矩陣乘算法

定義

在 1968 年之前,矩陣乘算法只有按定義實現的傳統算法,:
設:
A=(a11a12...a21a22............an1an2...)B=(b11b12...b21b22............bn1bn2...)A=\begin{pmatrix} a_{11} &a_{12} &... \\ a_{21} &a_{22} &... \\ ... & ... & ... \\ a_{n1} & a_{n2} & ... \\ \end{pmatrix} B=\begin{pmatrix} b_{11} &b_{12} &... \\ b_{21} &b_{22} &... \\ ... & ... & ... \\ b_{n1} & b_{n2} & ... \\ \end{pmatrix}
AB 爲其乘積,則:
[AB]pq=i=1napibiq[AB]_{pq} = \sum_{i=1}^{n}a_{pi}b_{iq}

很明顯,它是一個 n3n^3複雜度的算法,需要 n3n^3 次乘法和 n3n2n^3-n^2次加法。

矩陣乘表示

C=ABC = AB,A 爲 ele*l的矩陣,B 爲 lhl*h的矩陣,則稱這個矩陣乘是一個 [e,l,h][e, l, h] 的矩陣乘。

快速矩陣乘法的初步探索

Winograd 算法

請注意,這個不是我們通常所說的卷積優化算法,只是同一個人(Winograd大神)在 1968 年提出一種減少乘法數的矩陣乘算法。

其思路是通過兩次 n2n^2 的乘法預處理,將規模大的矩陣乘法減少一半,但相應的加法增加一半。爲了說明簡單,這裏假定nn爲偶數。
θp=j=1n/2(ap,2j1ap,2j)γq=j=1n/2(b2j1,qb2j,q)[AB]pq=j=1n/2(ap,2j1+b2j,q)(ap,2j+b2j1,q)θpγq\theta_p = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1} a_{p, 2j}) \\\gamma_q = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(b_{2j-1, q}b_{2j, q}) \\ [AB]_{pq} = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1}+b_{2j, q})(a_{p, 2j}+b_{2j-1, q}) - \theta_p - \gamma_q

這個算法沒有降低矩陣乘法的階(還是n3n^3),只是以廉價計算(加法)替代昂貴運算(乘法),需要根據具體的硬件去判斷是否可應用。ARM 架構的 CPU,對量化矩陣乘有幫助,但對浮點矩陣乘沒有用。

Strassen 矩陣乘算法

Strassen 矩陣乘的思路是通過加減變換,將一個 [2,2,2][2, 2, 2]的矩陣乘法所用的乘法數由8降到7,並且遞歸使用,降低矩陣乘法的階數:n3n^3變成n2.81n^{2.81}
A=(a11a12a21a22)B=(b11b12b21b22)AB=(c11c12c21c22) A=\begin{pmatrix} a_{11} &a_{12} \\ a_{21} &a_{22} \\ \end{pmatrix} B=\begin{pmatrix} b_{11} &b_{12} \\ b_{21} &b_{22} \\ \end{pmatrix} AB=\begin{pmatrix} c_{11} &c_{12} \\ c_{21} &c_{22} \\ \end{pmatrix}

v1=(a11+a22)(b11+b22)v2=(a21+a22)(b11)v3=(a11)(b12b22)v4=(a22)(b21b11)v5=(a11+a12)(b22)v6=(a21a11)(b11+b12)v7=(a12a22)(b21+b22)v_1 = (a_{11}+a_{22})(b_{11}+b_{22})\\ v_2 = (a_{21}+a_{22})(b_{11})\\v_3 = (a_{11})(b_{12}-b_{22})\\v_4 = (a_{22})(b_{21}-b_{11})\\v_5 = (a_{11}+a_{12})(b_{22})\\v_6 = (a_{21}-a_{11})(b_{11}+b_{12})\\v_7 = (a_{12}-a_{22})(b_{21}+b_{22})

c11=v1+v4v5+v7c21=v2+v4c12=v3+v5c22=v1+v3v2+v6c_{11} = v_1+v_4-v_5+v_7\\c_{21} = v_2+v_4\\c_{12} = v_3+v_5\\c_{22} = v_1+v_3-v_2+v_6

請注意,其中每個元素(a11,b12,c22a_{11}, b_{12}, c_{22}等等)不限於實數,可以是一個矩陣。因爲矩陣乘法滿足分配率與結合率。這樣算法就有了脫離硬件的普適價值,因爲矩陣加減的複雜度(n2n^2)遠低於矩陣乘(n3n^3)

Winograd 在 Strassen 的基礎上對它的算法進行了改進,減少了加減數(18->15),這個也成爲最常用的 Strassen 矩陣乘法應用。

三線性表示

爲了方便矩陣乘算法的研究,人們提出一種表示矩陣乘算法的形式,叫“Trilinear-form”,即三線性形式。
我們先以 Strassen 算法爲例,它的三線性形式是:
i=12j=12k=12aijbjkcik=(a11)(b12b22)(c12+c22)+(a11+a12)(b22)(c11+c12)+(a21+a22)(b11)(c21c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12a22)(b21+b22)(c11)+(a11a21)(b11+b12)(c22)\sum_{i=1}^2\sum_{j=1}^2\sum_{k=1}^2 a_{ij}b_{jk}c_{ik} = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22})

怎麼看這個公式呢,它其實是按 Trace(ABC)=ABTrace(ABC) = AB 的原理去表示的。兩個矩陣的乘積,等效於三個矩陣乘積的跡。在上面公式中,如果我們要算出 c11c_{11} 的解法,就將 c11c_{11} 設成 1,其他的 c 值,c12,c21,c22c_{12}, c_{21}, c_{22} 全設成 0 ,然後將對應的項相加即可。

這個算式總共有7項,這個 7 我們稱之爲 Rank (階)

APA——矩陣乘算法的突破

APA,即 Any Precision Algorithm,是把矩陣乘法階數繼續往下降的重要思想,基本思路是先給出近似的矩陣乘法表達式,然後在多階張量積之後轉換爲準確的矩陣乘法。

張量積

我們來看 Strassen 矩陣乘法的表達式:
λ=(a11)(b12b22)(c12+c22)+(a11+a12)(b22)(c11+c12)+(a21+a22)(b11)(c21c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12a22)(b21+b22)(c11)+(a11a21)(b11+b12)(c22)\lambda = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22})

對其平方:
λ2=((a11)(b12b22)(c12+c22)+(a11+a12)(b22)(c11+c12)+(a21+a22)(b11)(c21c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12a22)(b21+b22)(c11)+(a11a21)(b11+b12)(c22))2\lambda^2 = ((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22}))^2

這是個多項式乘法,不難知λ2\lambda^272=497^2=49 項,我們來看其中一項:
((a11)(b12b22)(c12+c22))((a11+a12)(b22)(c11+c12))=(a11a11+a11a12)(b12b22b22b22)(c12c11+c12c12c22c11+c22c12)((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}))((a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}))=(a_{11}a_{11}+a_{11}a_{12})(b_{12}b_{22}-b_{22}b_{22})(-c_{12}c_{11}+c_{12}c_{12}-c_{22}c_{11}+c_{22}c_{12})
(依然是將a, b, c 分別組合在一起)

a,b,ca, b, c間的相乘,如a11a12a_{11}a_{12},我們將其替代爲直和:a1112a_{1112},其含義可以這麼理解,在a11a_{11}的區域(左上角)中,再劃分爲四塊,取其a12a_{12}的區域(右上角)。
不難證明,我們通過這個多項式平方後得到的三線性形式,等效於一個 [4,4,4][4, 4, 4] 的矩陣乘法。

類似地,我們可以對矩陣乘法的三線性形式進行立方,n次方,以及兩個不同的三線性形式乘積,這一系列操作可由“張量積”概括。

APA

Any Precision Algorithm(APA),即任意精度算法,通過在算式中引入一個可配置的實數λ\lambda,得到更好的簡化效果。

下面的式子近似用21項表示了一個[3,3,3][3, 3, 3]的矩陣乘法

F1(λ)=(a11+λ2a12)(λ2b11+b21)c11+(a21+λ2a22)(λ2b12+b22)c22+(a31+λ2a32)(λ2b13+b23)c33a11(b21+b31)(c11+c12+c13)a21(b22+b32)(c21+c22+c23)a31(b23+b33)(c31+c32+c33)+(a11+λ2a22)(b21λb12)c12+(a21+λ2a12)(b22λb11)c21+(a11+λ2a32)(b21λb13)c13+(a31+λ2a12)(b23λb11)c31+(a21+λ2a32)(b22λb13)c23+(a31+λ2a22)(b23λb12)c32+(a11+λ2a23)(b31+λb12)(c12+λc21)+(a21+λ2a13)(b32+λb11)(c21+λc12)+(a11+λ2a33)(b31+λb13)(c13+λc31)+(a31+λ2a13)(b33+λb12)(c31+λc13)+(a21+λ2a33)(b32+λb13)(c23+λc32)+(a31+λ2a23)(b33+λb12)(c32+λc23)+(a11+λ2a13)b31(c11λc31λc21)+(a21+λ2a23)b32(c22λc32λc12)+(a31+λ2a33)b33(c33λc13λc23)=λ2(Trace(ABC)+λG(λ))F_1(\lambda) = (a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11}\\+(a_{21}+\lambda^2a_{22})(\lambda^2b_{12}+b_{22})c_{22}+(a_{31}+\lambda^2a_{32})(\lambda^2b_{13}+b_{23})c_{33}-a_{11}(b_{21}+b_{31})(c_{11}+c_{12}+c_{13})-a_{21}(b_{22}+b_{32})(c_{21}+c_{22}+c_{23})-a_{31}(b_{23}+b_{33})(c_{31}+c_{32}+c_{33})+(a_{11}+\lambda^2a_{22})(b_{21}-\lambda b_{12})c_{12}+(a_{21}+\lambda^2a_{12})(b_{22}-\lambda b_{11})c_{21}+(a_{11}+\lambda^2a_{32})(b_{21}-\lambda b_{13})c_{13}+(a_{31}+\lambda^2a_{12})(b_{23}-\lambda b_{11})c_{31}+(a_{21}+\lambda^2a_{32})(b_{22}-\lambda b_{13})c_{23}+(a_{31}+\lambda^2a_{22})(b_{23}-\lambda b_{12})c_{32}+(a_{11}+\lambda^2a_{23})(b_{31}+\lambda b_{12})(c_{12}+\lambda c_{21})+(a_{21}+\lambda^2a_{13})(b_{32}+\lambda b_{11})(c_{21}+\lambda c_{12})+(a_{11}+\lambda^2a_{33})(b_{31}+\lambda b_{13})(c_{13}+\lambda c_{31})+(a_{31}+\lambda^2a_{13})(b_{33}+\lambda b_{12})(c_{31}+\lambda c_{13})+(a_{21}+\lambda^2a_{33})(b_{32}+\lambda b_{13})(c_{23}+\lambda c_{32})+(a_{31}+\lambda^2a_{23})(b_{33}+\lambda b_{12})(c_{32}+\lambda c_{23})+(a_{11}+\lambda^2a_{13})b_{31}(c_{11}-\lambda c_{31}-\lambda c_{21})+(a_{21}+\lambda^2a_{23})b_{32}(c_{22}-\lambda c_{32}-\lambda c_{12})+(a_{31}+\lambda^2a_{33})b_{33}(c_{33}-\lambda c_{13}-\lambda c_{23}) = \lambda^2 (Trace(ABC)+\lambda G(\lambda))

λ\lambda趨於無窮小時,其誤差也趨於無窮小,因此我們可以設定任意的精度去使用它,這就是 APA 的由來。

對於 APA 算法,多項式的個數我們稱之爲 Border Rank,上述算式表示了一個[3,3,3][3, 3, 3]的矩陣乘法,在λ3\lambda ^3的基礎上分出誤差,我們稱之爲一個降解:[3,3,3]321[3, 3, 3] \unlhd_3 21

現在我們來看怎麼把上面的 APA 算法變成準確算法。

直觀的做法就是把λ2\lambda^2項取出來,如:(a11+λ2a12)(λ2b11+b21)c11(a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11},取出 λ2a11b11c11+λ2a12b21c11\lambda^2a_{11}b_{11}c_{11}+\lambda^2a_{12}b_{21}c_{11},代價就是增加了多項式,不難證明,我們最多會增加到 2(2+1)/2=32(2+1)/2=3倍的多項式個數。

無疑,這樣做肯定虧了,321=63>333=273*21=63 > 3*3*3=27,我們需要施個魔法,就是張量積。

對上面APA 算法進行n次張量積之後,我們可以得到3n3^n大小的矩陣乘算法的降解:[3n,3n,3n]2n+121n[3^n, 3^n, 3^n] \unlhd_{2n+1} 21^n

這時候我們再來取,就不一樣了,其階數變成了:
n(2n+1)21nn(2n+1)21^n
很明顯,當 n 足夠大時,n(2n+1)n(2n+1) 和指數項相比可忽略,這樣我們就得到了更好的準確算法,其階數爲:
3ln(21)/ln(27)2.773ln(21)/ln(27)\approx2.77

下篇內容:
1、組合矩陣乘
2、漸近和定理
3、Strassen構造
4、Coppersmith–Winograd 算法

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章