《算法導論》——矩陣乘法Strassen算法

注:本文爲《算法導論》中分治相關內容的筆記。對此感興趣的讀者還望支持原作者。

矩陣乘法

接觸過線性代數的讀者,對於矩陣乘法想必一定不陌生。若A=(aij)A=(a_{ij})B=(bij)B=(b_{ij})nnn*n的方陣,則對i,j,,ni, j, \ldots, n,定義乘積C=ABC=A \cdot B中的元素cijc_{ij}爲:

cij=k=1naikbkjc_{ij}=\sum_{k=1}^{n}a_{ik}b_{kj}

因此,我們可以根據矩陣乘法的定義給出矩陣乘法的僞代碼。它接收nnn * n的矩陣AABB,返回它們的乘積——nnn * n的矩陣CC,並且假設每個矩陣都有一個屬性rowsrows,表示矩陣的行數。

樸素算法

不難看出,由於三重for循環都恰好執行nn步,而第7行每次執行都花費常量時間。因此,SQUARE-MATRIX-MULTIPLY的時間複雜度爲θ(n3)\theta (n^3),即矩陣乘法的樸素實現需要花費θ(n3)\theta (n^3)時間。你可能因此認爲任何矩陣乘法都要花費Ω(n3)\Omega (n^3)時間,因爲矩陣乘法的自然定義就需要進行這麼多次的標量乘法。而在學術界,也的確在很長一段時間內,很少人敢設想一個算法能漸近快於平凡算法SQUARE-MATRIX-MULTIPLY,直至Strassen大神的出現。

算法流程

Strassen算法採用分治法解決矩陣乘積問題,並通過排列組合的技巧使得分治法產生的遞歸樹不那麼“茂盛”以減少矩陣乘法的次數。Strassen算法並不直觀,它包含4個步驟:

  1. 將輸入矩陣ABA、B和輸出矩陣CC通過以下方式分解爲n2n2\frac{n}{2} * \frac{n}{2}的子矩陣;
    A=[A11A12A21A22],B=[B11B12B21B22],C=[C11C12C21C22]A = \left [ \begin{matrix} A_{11} & A_{12} \\ A_{21} & A_{22} \\ \end{matrix} \right ], B = \left [ \begin{matrix} B_{11} & B_{12} \\ B_{21} & B_{22} \\ \end{matrix} \right ], C = \left [ \begin{matrix} C_{11} & C_{12} \\ C_{21} & C_{22} \\ \end{matrix} \right ]

  2. 創建10個n2n2\frac{n}{2} * \frac{n}{2}的矩陣S1,S2,,S10S_1, S_2, \ldots , S_{10},每個矩陣保存步驟1中創建的兩個子矩陣的和或差,時間複雜度爲Θ(n2)\Theta (n^2)

  3. 用步驟1中創建的子矩陣和步驟2中創建的10個矩陣,遞歸地計算7個矩陣積P1,P2,,P7P_1, P_2, \ldots , P_7。每個矩陣PiP_i都是n2n2\frac{n}{2} * \frac{n}{2}的;

  4. 通過PiP_i矩陣的不同組合進行加減計算,計算出矩陣CC的子矩陣C11,C12,C21,C22C_{11}, C_{12}, C_{21}, C_{22},時間複雜度爲Θ(n2)\Theta(n^2)

是不是感覺很抽象?一頓猛如虎的操作,就能完成矩陣乘積計算了?沒錯,就是這麼。接下來,爲了幫助大家掌握這種操作,就再看看Strassen算法的細節。在步驟2中,創建如下10個矩陣:

S1=B12B22S_1 = B_{12} - B_{22}

S2=A11+A12S_2 = A_{11} + A_{12}

S3=A21+A22S_3 = A_{21} + A_{22}

S4=B21B11S_4 = B_{21} - B_{11}

S5=A11+A22S_5 = A_{11} + A_{22}

S6=B11+B22S_6 = B_{11} + B_{22}

S7=A12A22S_7 = A_{12} - A_{22}

S8=B21+B22S_8 = B_{21} + B_{22}

S9=A11A21S_9 = A_{11} - A_{21}

S10=B11+B22S_{10} = B_{11} + B_{22}

由於必須進行10次n2n2\frac{n}{2} * \frac{n}{2}的加減法,因此,該步驟花費Θ(n2)\Theta(n^2)

在步驟三中,遞歸地計算7次n2n2\frac{n}{2} * \frac{n}{2}矩陣的乘法,如下所示:

P1=A11S1=A11B12A11B22P_1 = A_{11} \cdot S_1 = A_{11} \cdot B_{12} - A_{11} \cdot B_{22}

P2=S2B22=A11B22+A12B22P_2 = S_2 \cdot B_{22} = A_{11} \cdot B_{22} + A_{12} \cdot B_{22}

P3=S3B11=A21B11+A22B11P_3 = S_3 \cdot B_{11} = A_{21} \cdot B_{11} + A_{22} \cdot B_{11}

P4=A22S4=A22B21A22B11P_4 = A_{22} \cdot S_4 = A_{22} \cdot B_{21} - A_{22} \cdot B_{11}

P5=S5S6=A11B11+A11B22+A22B11+A22B22P_5 = S_5 \cdot S_6 = A_{11} \cdot B_{11} + A_{11} \cdot B_{22} + A_{22} \cdot B_{11} + A_{22} \cdot B_{22}

P6=S7S8=A12B21+A12B22A22B21A22B22P_6 = S_7 \cdot S_8 = A_{12} \cdot B_{21} + A_{12} \cdot B_{22} - A_{22} \cdot B_{21} - A_{22} \cdot B_{22}

P7=S9S10=A11B11+A11B12A21B11A21B12P_7 = S_9 \cdot S_10 = A_{11} \cdot B_{11} + A_{11} \cdot B_{12} - A_{21} \cdot B_{11} - A_{21} \cdot B_{12}

步驟4對步驟3創建的PiP_i矩陣進行加減法運算,計算出CC的4個n2n2\frac{n}{2} * \frac{n}{2}的子矩陣。

C11=P5+P4P2+P6=A11B11+A12B21C_{11} = P_5 + P_4 - P_2 + P_6 = A_{11} \cdot B_{11} + A_{12} \cdot B_{21}

C12=P1+P2=A11B12+A12B22C_{12} = P_1 + P_2 = A_{11} \cdot B_{12} + A_{12} \cdot B_{22}

C21=P3+P4=A21B11+A22B21C_{21} = P_3 + P_4 = A_{21} \cdot B_{11} + A_{22} \cdot B_{21}

C22=P5+P1P3P7=A22B22+A21B12C_{22} = P_5 + P_1 - P_3 - P_7 = A_{22} \cdot B_{22} + A_{21} \cdot B_{12}

如此,我們便獲得矩陣AABB的乘積——矩陣CC

算法分析

之前說過,Strassen算法的時間複雜度是優於樸素計算的,可是,它到底是多少呢?我們不妨再回到Strassen算法的流程。當n>1n > 1時,步驟1、2和4共花費θ(n2)\theta(n^2)時間,步驟3要求7次n2n2\frac{n}{2} * \frac{n}{2}矩陣的乘法。因此,我們得到如下描述Strassen算法運行時間T(n)T(n)的遞歸式:

T(n)={θ(1)n=17T(n/2)+θ(n2)n>1 T(n)=\left\{ \begin{aligned} & \theta(1) & 若n = 1\\ & 7T(n/2) + \theta(n^2) & 若n > 1\\ \end{aligned} \right.

求解上式可得,T(n)=θ(nlg7)T(n) = \theta(n^{\lg7})

算法實現

廢話千句,不如代碼兩行,接下來直接上Strassen算法的實現。(注意,如果nn不是2的冪,可以採取對原矩陣填充0的方式,使nn擴展到2的冪)。

Strassen算法

算法總結

Strassen算法發表於1969年,它的發表引起了很大的轟動。在此之前,很少人敢設想一個算法能漸近快於平凡算法SQUARE-MATRIX-MULTIPLY。矩陣乘法的上界自此被改進了。到目前爲止,nnn*n矩陣相乘的漸近複雜性最優的算法是Coppersmith和Winograd提出的,運行時間是O(n2.376)O(n^{2.376})

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章