算法导论 — 4.2 矩阵乘法的Strassen算法

笔记

给定两个n×nn×n正方矩阵AABB,这两个矩阵的乘法定义为
      在这里插入图片描述
  其中
      在这里插入图片描述
  下面是矩阵乘法的伪代码。
  在这里插入图片描述
  很显然,执行SQUARE-MATRIX-MULTIPLY需要花费Θ(n3)Θ(n^3)时间。然而,有一种方法可以花费更少的时间,这就是Strassen算法,它本质上也是一种分治法,它的时间复杂度为Θ(nlg7)=O(n2.81)Θ(n^{{\rm lg}7}) = O(n^{2.81})
  在介绍Strassen算法之前,我们先尝试简单的分治法来计算矩阵乘法C=ABC = A•B。假定三个矩阵均为n×nn×n正方矩阵。并且为简化分析,假定nn22的幂。我们将ABA、BCC均分解为44(n/2)×(n/2)(n/2)×(n/2)的子矩阵。
    在这里插入图片描述
  于是矩阵乘法可以表示为
    在这里插入图片描述
  上面的矩阵乘法等价于下面44个式子。
      在这里插入图片描述
  上面每个式子都对应22(n/2)×(n/2)(n/2)×(n/2)矩阵乘法,以及11(n/2)×(n/2)(n/2)×(n/2)矩阵加法。根据以上分析,可以给出一个递归的分治算法。
  在这里插入图片描述
  现在分析这个简单分治法的时间复杂度。调用SQUARE-MATRIX-MULTIPLY-RECURSIVE计算两个n×nn×n矩阵乘法的运行时间用T(n)T(n)表示。对于n=1n = 1的初始情况,我们只需计算一次标量乘法,因此T(1)=Θ(1)T(1) = Θ(1)。当n>1n > 1时,根据上面的伪代码,T(n)T(n)包含88(n/2)×(n/2)(n/2)×(n/2)矩阵乘法的时间和44(n/2)×(n/2)(n/2)×(n/2)矩阵加法的时间,所以T(n)=8T(n/2)+Θ(n2)T(n) = 8T(n/2) + Θ(n^2),这里忽略了分解子矩阵的时间。于是,我们得到SQUARE-MATRIX-MULTIPLY-RECURSIVE的运行时间的递时式为
    在这里插入图片描述
  求解这个递归式得到T(n)=Θ(n3)T(n) = Θ(n^3)。可以看到,这个简单的分治法并没有带来渐近运行时间的提升。
  下面介绍Strassen算法。Strassen算法同样要将每个矩阵分解为44(n/2)×(n/2)(n/2)×(n/2)子矩阵。而与简单分治法不同,Strassen算法只需要递归为77次,而不是88次。下面直接给出Strassen算法的流程。
  (1) 将输入矩阵ABA、B以及输出矩阵CC各分解为44(n/2)×(n/2)(n/2)×(n/2)子矩阵。
  (2) 创建1010(n/2)×(n/2)(n/2)×(n/2)矩阵S1,S2,,S10S_1, S_2, …, S_{10},如下所示。由于需要进行1010(n/2)×(n/2)(n/2)×(n/2)矩阵的加减法,所以这一步花费Θ(n2)Θ(n^2)时间。
    在这里插入图片描述
  (3) 用步骤(1)分解得到的子矩阵和步骤(2)中创建的1010个矩阵,递归地计算77个矩阵乘积P1,P2,,P7P_1, P_2, …, P_7,如下所示。
    在这里插入图片描述
  (4) 利用矩阵P1,P2,,P7P_1, P_2, …, P_7进行加减运算,得到输出矩阵CC的子矩阵C11,C12,C21,C22C_{11}, C_{12}, C_{21}, C_{22},如下所示。这一步需要进行88(n/2)×(n/2)(n/2)×(n/2)矩阵的加减法,所以花费时为Θ(n2)Θ(n^2)
    在这里插入图片描述
  由于Strassen算法只需要递归为77次,因此它的运行时间的递归式为
    在这里插入图片描述
  求解这个递归式,可以得到Strassen算法的运行时间T(n)=Θ(nlg7)T(n) = Θ(n^{{\rm lg}7})

练习

4.2-1 使用Strassen算法计算如下矩阵乘法:
    在这里插入图片描述
  给出计算过程。
  
  (1) 分解输入矩阵
    在这里插入图片描述
  (2) 计算矩阵S1,S2,,S10S_1, S_2, …, S_{10}
    在这里插入图片描述
  (3) 计算矩阵P1,P2,,P7P_1, P_2, …, P_7
    在这里插入图片描述
  (4) 计算输出矩阵的44个子矩阵
    在这里插入图片描述
  最终结果为
    在这里插入图片描述
  
4.2-2 为Strassen算法编写伪代码。
  
  这里还是假设了矩阵的宽高nn22的幂。下面给出伪代码。
  在这里插入图片描述
  
4.2-3 如何修改Strassen算法,使之适应矩阵规模nn不是22的幂的情况?证明:算法的运行时间为Θ(nlg7)Θ(n_{{\rm lg}7})
  
  为了保证算法的通用性,需要考虑矩阵的宽高nn不为22的幂的情况。分两种情况讨论。
  (1) nn为偶数
  这种情况下n×nn×n矩阵可以直接分解为44(n/2)×(n/2)(n/2)×(n/2)的子矩阵,因此可以直接应用Strassen算法。为了计算矩阵乘法Cn×n=An×nBn×nC_{n×n} = A_{n×n}•B_{n×n},令m=n/2m = n/2,需要将矩阵分解为
    在这里插入图片描述
  这种情况下,矩阵乘法所花费的时间T(n)=7T(n/2)+Θ(n2)T(n) = 7T(n/2) + Θ(n^2)
  (2) nn为奇数
  这种情况不能直接应用Strassen算法。为了计算矩阵乘法Cn×n=An×nBn×nC_{n×n} = A_{n×n}•B_{n×n},令m=n1m = n−1,将矩阵做如下分解
    在这里插入图片描述
  如上所示,每个n×nn×n矩阵被分解为一个(n1)×(n1)(n−1)×(n−1)矩阵、一个(n1)×1(n−1)×1矩阵、一个1×(n1)1×(n−1)矩阵和一个1×11×1矩阵。相应地,矩阵乘法Cn×n=An×nBn×nC_{n×n} = A_{n×n}•B_{n×n}可以分解为下面44个式子。
    在这里插入图片描述
  上面4个式子包含了8个不同规模的矩阵乘法,下面逐个进行分析。
  1) A11m×mB11m×mA11_{m×m}•B11_{m×m}:由于m=n1m = n−1是偶数,所以这个矩阵乘法可以直接应用Strassen算法。
  这一矩阵乘法所花费的时间为T(n1)=7T((n1)/2)+Θ((n1)2)=7T(n/2)+Θ(n2)T(n-1)=7T((n-1)/2)+Θ((n-1)^2)=7T(⌊n/2⌋)+Θ(n^2)
  2) A12m×1B211×mA12_{m×1}•B21_{1×m}:采用朴素算法,需要做(n1)2(n−1)^2次乘法,因此运行时间为Θ(n2)Θ(n^2)
  3) A11m×mB12m×1A11_{m×m}•B12_{m×1}:采用朴素算法,需要做(n1)2(n−1)^2次乘法和(n1)(n2)(n−1)(n−2)次加法,因此运行时间也为Θ(n2)Θ(n^2)
  4) A12m×1B221×1A12_{m×1}•B22_{1×1}:采用朴素算法,需要做(n1)(n−1)次乘法,运行时间为Θ(n)Θ(n)
  5) A211×mB11m×mA21_{1×m}•B11_{m×m}:采用朴素算法,需要做(n1)2(n−1)^2次乘法,以及(n1)(n2)(n−1)(n−2)次加法,运行时间为Θ(n2)Θ(n^2)
  6) A221×1B211×mA22_{1×1}•B21_{1×m}:采用朴素算法,需要做(n1)(n−1)次乘法,运行时间为Θ(n)Θ(n)
  7) A211×mB12m×1A21_{1×m}•B12_{m×1}:采用朴素算法,需要做(n1)(n−1)次乘法,以及(n2)(n−2)次加法,运行时间为Θ(n)Θ(n)
  8) A221×1B221×1A22_{1×1}•B22_{1×1}:这仅仅是两个元素的相乘,只花费Θ(1)Θ(1)时间。
  根据以上分析,除去A11m×mB11m×mA11_{m×m}•B11_{m×m}之外,其他77个矩阵乘法加起来的运行时间为Θ(n2)Θ(n^2)。因此,当nn为奇数时,n×nn×n矩阵乘法的运行时间为
    T(n)=7T(n/2)+Θ(n2)T(n)=7T(⌊n/2⌋)+Θ(n^2)
  综合以上两种情况,无论nn为奇数还是偶数,矩阵乘法的运行时间都为T(n)=7T(n/2)+Θ(n2)T(n)=7T(⌊n/2⌋)+Θ(n^2)。忽略其中的⌊ ⌋符号,这与之前分析的Strassen算法的运行时间是一样的。
  下面给出具备通用性的Strassen算法的伪代码。
  在这里插入图片描述
  在这里插入图片描述
  
4.2-4 如果可以用kk次乘法操作(假定乘法的交换律不成立)完成两个3×33×3矩阵相乘,那么你可以在o(nlg7)o(n^{{\rm lg}7})时间内完成n×nn×n矩阵相乘,满足这一条件的最大kk是多少?此算法的运行时间是怎样的?
  
  仍然采用Strassen算法。我们现在分析该算法运行时间的递归式,不过在这里需要以T(3)T(3)作为边界条件,递归式如下所示。
    在这里插入图片描述
  如果我们画出递归树,该递归树一共有lg(n/3)lg(n/3)层。叶结点对应子问题T(3)T(3)。由于每层的结点数是上一层的77倍,因此第ii层包含7i7^i个结点。因此,叶结点一共有7lg(n/3)=7lgnlg3=7lgn/7lg3=nlg7/7lg37^{{\rm lg}(n/3)} =7^{{\rm lg}n-{\rm lg}3}=7^{{\rm lg}n}/7^{{\rm lg}3} =n^{{\rm lg}7}/7^{{\rm lg}3}个。因此所有叶结点的代价之和为(nlg7/7lg3)T(3)=k(nlg7/7lg3)(n^{{\rm lg}7}/7^{{\rm lg}3})•T(3)=k•(n^{{\rm lg}7}/7^{{\rm lg}3})
  如果要在o(nlg7)o(n_{{\rm lg}7})时间内完成n×nn×n矩阵相乘,那么必然有k(nlg7/7lg3)<nlg7k•(n^{{\rm lg}7}/7^{{\rm lg}3})<n^{{\rm lg}7},于是得到k<7lg321.85k<7^{{\rm lg}3}≈21.85。所以kk的最大值为2121
  
4.2-5 V.Pan发现一种方法,可以用132464132 464次乘法操作完成68×6868×68的矩阵相乘,发现另一种方法,可以用143640143 640次乘法操作完成70×7070×70的矩阵相乘,还发现一种方法,可以用155424155 424次乘法操作完成72×7272×72的矩阵相乘。当用于矩阵相乘的分治算法时,上述哪种方法会得到最佳的渐近运行时间?与Strassen算法相比,性能如何?
  
  对于采用分治法的矩阵乘法算法来说,其运行时间都为Θ(nd)Θ(n^d),其中dd为一个正常数。现在分析题目所给的33种方法,其渐近运行时间中的dd分别为多少。为方便起见,假设33种方法的运行时间分别为T1(n)=nd1T2(n)=nd2T_1(n)=n^{d_1},T_2(n)=n^{d_2}T3(n)=nd3T_3(n)=n^{d_3}
  用132464132 464次乘法操作完成68×6868×68的矩阵相乘,于是有
    T1(68)=68d1=132464T_1 (68)=68^{d_1}=132464
  得到d1=log681324642.795128d_1={\rm log}_{68}132464≈2.795128
  用143640143 640次乘法操作完成70×7070×70的矩阵相乘,于是有
    T2(70)=70d2=143640T_2 (70)=70^{d_2}=143640
  得到d2=log701436402.795122d_2={\rm log}_{70}143640≈2.795122
  用155424155 424次乘法操作完成72×7272×72的矩阵相乘,于是有
    T3(72)=72d3=155424T_3 (72)=72^{d_3}=155424
  得到d3=log721554242.795147d_3={\rm log}_{72}155424≈2.795147
  根据以上分析,第(2)种方法的渐近运行时间的指数d2d_2是最小的,所以第(2)种方法会得到最佳的渐近运行时间。
  Strassen算法的渐近运行时间为Θ(nlg7)Θ(nlg7)lg72.807355>d2{\rm lg}7 ≈ 2.807355 > d_2,因此上述第(2)种方法的性能是优于Strassen算法的。
  
4.2-6 用Strassen算法作为子过程来进行一个kn×nkn×n矩阵和一个n×knn×kn矩阵相乘,最快需要花费多长时间?对两个输入矩阵规模互换的情况,回答相同的问题。
  
  两个矩阵Akn×nA_{kn×n}Bn×knB_{n×kn}相乘,得到矩阵Ckn×knC_{kn×kn}。如果要利用Strassen算法,则需要将矩阵ABA、BCC按下面的方式分解
    在这里插入图片描述
  矩阵CC的任意一个子矩阵Cij=AiBjC_{ij} = A_i • B_j, 这是一个n×nn×n矩阵乘法,采用Strassen算法,运行时间为Θ(nlg7)Θ(n^{{\rm lg}7})。一共有k2k^2个这样的n×nn×n矩阵乘法,所以总的运行时间为Θ(k2nlg7)Θ(k^2•n^{{\rm lg}7})
  如果将输入矩阵的规模互换,即矩阵An×knA_{n×kn}Bkn×nB_{kn×n}相乘,得到矩阵Cn×nCn×n,那么需要将矩阵AABB按下面的方式分解
    在这里插入图片描述
  矩阵C=A1B1+A2B2++AkBkC = A1 • B1 + A2 • B2 + … + Ak • Bk。一共有kkn×nn×n矩阵乘法,并且还有(k1)(k−1)n×nn×n矩阵加法,所以总的运行时间为Θ(knlg7)Θ(k•n^{{\rm lg}7})
  
4.2-7 设计算法,仅使用三次实数乘法即可完成复数a+bia+bic+dic+di相乘。算法需接收abca、b、cdd为输入,分别生成实部acbdac−bd和虚部ad+bcad+bc
  
  借鉴Strassen算法的思想,该问题可以按以下步骤解决。
  (1) 计算P1P2P_1、P_2P3P_3
    P1=adP_1 = ad
    P2=bcP_2 = bc
    P3=(ab)(c+d)=acbd+adbcP_3 = (a – b)(c + d) = ac – bd + ad – bc
  (2) 计算实部和虚部
    实部:P3P1+P2=acbdP_3 – P_1 + P_2 = ac−bd
    虚部:P1+P2=ad+bcP_1 + P_2 = ad+bc
  该算法只需要33次乘法即可。
  
  本节代码链接:
  https://github.com/yangtzhou2012/Introduction_to_Algorithms_3rd/tree/master/Chapter04/Section_4.2

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章