問題:
設 和 是兩個 階矩陣,求它們的乘積矩陣C。這裏,假設 是 的冪次方。
一、問題分析(模型、算法設計和正確性證明等)
實驗要求使用分治法解決n階矩陣(n是2的冪次方)相乘問題,因爲n是2的冪次方,可以使用樸素分塊矩陣乘法或者 Strassen 法,這裏兩種都嘗試一下,順便連蠻力法也放進去。
二、複雜度分析
蠻力法僞代碼:
for i = 1 to n do:
for j = 1 to n do:
for k = 1 to n do:
C[i][j] = C[i][j] + A[i][k]・B[k][j]
顯然時間複雜度爲.
樸素分塊矩陣乘法僞代碼:
Divide_And_Conquer(int[][]A,int[][]B,int n){
int [][]C = new int[n][n]; //定義一個新矩陣存放結果
if n==1:
C11=A11*B11;
else Divide A, B and C as in 4 equation:
n /= 2;
C11=Divide_And_Conquer(A11,B11,n) + Divide_And_Conquer(A12,B21,n);
C22=Divide_And_Conquer(A11,B12,n) + Divide_And_Conquer(A12,B22,n);
C21=Divide_And_Conquer(A21,B11,n) + Divide_And_Conquer(A22,B21,n);
C22=Divide_And_Conquer(A21,B22,n) + Divide_And_Conquer(A22,B22,n);
return C;
}
因爲n/2 * n/2 的矩陣乘法進行了8次, n/2 * n/2的矩陣加法進行了4次所以複雜度爲:
Strassen 法僞代碼:
Strassen_DAC(int [][]A, int [][]B,int n){
int [][]C = new int[n][n]; //定義結果矩陣
if n==1:
C11 = A11*B11;
else Divide A, B, and C as in 4 equation:
n /= 2;
int [][]M1,M2,M3,M4,M5,M6,M7 = new int[n][n];
M1 = Strassen_DAC(A11, B12-B22, n);
M2 = Strassen_DAC(A11+A12, B22, n);
M3 = Strassen_DAC(A21+A22, B11, n);
M4 = Strassen_DAC(A22, B21-B11, n);
M5 = Strassen_DAC(A11+A12, B11+B12, n);
M6 = Strassen_DAC(A12-A22, B21+B22, n);
M7 = Strassen_DAC(A11-A21, B11+B12, n);
C11 = M5 + M4 - M2 + M6;
C12 = M1 + M2;
C21 = M3 + M4;
C22 = M5 + M1 -M3 -M7;
return C;
}
從僞代碼明顯可以看出,程序執行了7次n/2 * n/2的矩陣乘法,以及 18次n/2 *n/2的矩陣加減運算,所以複雜度爲:
三、程序實現和測試過程和結果(主要描述出現的問題和解決方法)
算法的思路倒是不難,難的是具體實現的時候,矩陣的分塊操作,容易繞不清楚。而且其中還有矩陣的加減運算,得出的結果最後還要把 合併成爲一個矩陣,這些在僞代碼裏都沒有給出來,但是複雜繁瑣容易出bug的正是這些細節。
實驗中使用Java語言編寫將三種方法放到同一個類中,一下爲類中的各個方法:
源碼:
package root;
/**
* @author 宇智波Akali
* 這是三種矩陣乘法
* @date 2020.3.18
*/
public class Try {
//創建一個隨機數構成的n*n矩陣
public static int[][] initializationMatrix(int n){
int[][] result = new int[n][n];//創建一個n*n矩陣
for(int i = 0;i < n;i++){
for(int j = 0;j < n;j++){
result[i][j] = (int)(Math.random()*10); //隨機生成1~10之間的數
}
}
return result;
}
//蠻力法求矩陣相乘
public static int[][] BruteForce(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
result[i][j] = 0;
for(int k=0;k<n;k++){
result[i][j] += p[i][k]*q[k][j];
}
}
}
return result;
}
//分治法求矩陣相乘
public static int[][] DivideAndConquer(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];//創建一個n*n矩陣
//當n爲2時,用蠻力法求矩陣相乘,返回結果結果
if(n == 2){
result = BruteForce(p,q,n);
return result;
}
//當n大於3時,採用分治法,遞歸求最終結果
if(n > 2){
int m = n/2;
//將矩陣p分成四塊
int[][] p1 = QuarterMatrix(p,n,1);
int[][] p2 = QuarterMatrix(p,n,2);
int[][] p3 = QuarterMatrix(p,n,3);
int[][] p4 = QuarterMatrix(p,n,4);
//將矩陣q分成四塊
int[][] q1 = QuarterMatrix(q,n,1);
int[][] q2 = QuarterMatrix(q,n,2);
int[][] q3 = QuarterMatrix(q,n,3);
int[][] q4 = QuarterMatrix(q,n,4);
//將結果矩陣分成同等大小的四塊
int[][] result1 = QuarterMatrix(result,n,1);
int[][] result2 = QuarterMatrix(result,n,2);
int[][] result3 = QuarterMatrix(result,n,3);
int[][] result4 = QuarterMatrix(result,n,4);
//最關鍵的步驟,遞歸調用DivideAndConquer()函數,並用公式相加
result1 = AddMatrix(DivideAndConquer(p1,q1,m),DivideAndConquer(p2,q3,m),m);//y=ae+bg
result2 = AddMatrix(DivideAndConquer(p1,q2,m),DivideAndConquer(p2,q4,m),m);//s=af+bh
result3 = AddMatrix(DivideAndConquer(p3,q1,m),DivideAndConquer(p4,q3,m),m);//t=ce+dg
result4 = AddMatrix(DivideAndConquer(p3,q2,m),DivideAndConquer(p4,q4,m),m);//u=cf+dh
//合併,將四塊小矩陣合成整體
result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四個小矩陣合併成一個大矩陣
}
return result;
}
//strassen法
public static int[][] Strassen(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];//創建一個n*n矩陣
if( n == 2){
result = BruteForce(p,q,n);
return result;
}
int m = n/2;
//將矩陣p分成四塊
int[][] p1 = QuarterMatrix(p,n,1);
int[][] p2 = QuarterMatrix(p,n,2);
int[][] p3 = QuarterMatrix(p,n,3);
int[][] p4 = QuarterMatrix(p,n,4);
//將矩陣q分成四塊
int[][] q1 = QuarterMatrix(q,n,1);
int[][] q2 = QuarterMatrix(q,n,2);
int[][] q3 = QuarterMatrix(q,n,3);
int[][] q4 = QuarterMatrix(q,n,4);
int[][] m1 = DivideAndConquer(AddMatrix(p1,p4,m),AddMatrix(q1,q4,m),m);
int[][] m2 = Strassen(AddMatrix(p3,p4,m),q1,m);
int[][] m3 = Strassen(p1,ReduceMatrix(q2,q4,m),m);
int[][] m4 = Strassen(p4,ReduceMatrix(q3,q1,m),m);
int[][] m5 = Strassen(AddMatrix(p1,p2,m),q4,m);
int[][] m6 = Strassen(ReduceMatrix(p3,p1,m),AddMatrix(q1,q2,m),m);
int[][] m7 = Strassen(ReduceMatrix(p2,p4,m),AddMatrix(q3,q4,m),m);
//將結果矩陣分成同等大小的四塊
int[][] result1 = QuarterMatrix(result,n,1);
int[][] result2 = QuarterMatrix(result,n,2);
int[][] result3 = QuarterMatrix(result,n,3);
int[][] result4 = QuarterMatrix(result,n,4);
result1 = AddMatrix(ReduceMatrix(AddMatrix(m1,m4,m),m5,m),m7,m);
result2 = AddMatrix(m3,m5,m);
result3 = AddMatrix(m2,m4,m);
result4 = AddMatrix(AddMatrix(ReduceMatrix(m1,m2,m),m3,m),m6,m);
result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四個小矩陣合併成一個大矩陣
return result;
}
//獲取矩陣的四分之一,number用來確定返回哪一個四分之一
public static int[][] QuarterMatrix(int[][] p,int n,int number){
int rows = n/2; //行數減半
int cols = n/2; //列數減半
int[][] result = new int[rows][cols];
switch(number){
//左上
case 1 :
{
for(int i=0;i<rows;i++)
for(int j=0;j<cols;j++)
result[i][j] = p[i][j];
break;
}
//右上
case 2 :
{
for(int i=0;i<rows;i++)
for(int j=0;j<n-cols;j++)
result[i][j] = p[i][j+cols];
break;
}
//左下
case 3 :
{
for(int i=0;i<n-rows;i++)
for(int j=0;j<cols;j++)
result[i][j] = p[i+rows][j];
break;
}
//右下
case 4 :
{
for(int i=0;i<n-rows;i++)
for(int j=0;j<n-cols;j++)
result[i][j] = p[i+rows][j+cols];
break;
}
default:
break;
}
return result;
}
//把均分爲四分之一的矩陣,合成一個矩陣
public static int[][] TogetherMatrix(int[][] a,int[][] b,int[][] c,int[][] d,int n){
int[][] result = new int[2*n][2*n];
for(int i=0;i<2*n;i++){
for(int j=0;j<2*n;j++){
if(i<n){
if(j<n)
result[i][j] = a[i][j];
else
result[i][j] = b[i][j-n];
}else{
if(j<n)
result[i][j] = c[i-n][j];
else
result[i][j] = d[i-n][j-n];
}
}
}
return result;
}
//求兩個矩陣相加結果
public static int[][] AddMatrix(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
result[i][j] = p[i][j]+q[i][j];
}
}
return result;
}
//求兩個矩陣相減結果
public static int[][] ReduceMatrix(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
result[i][j] = p[i][j]-q[i][j];
}
}
return result;
}
//輸出矩陣的函數
public static void PrintfMatrix(int[][] matrix,int n){
for(int i=0;i<n;i++){
for(int j=0;j<n;j++)
System.out.printf("% 4d",matrix[i][j]);
System.out.println();
}
}
public static void main(String args[]){
int[][] p = initializationMatrix(8);
int[][] q = initializationMatrix(8);
//輸出生成的兩個矩陣
System.out.println("p:");
PrintfMatrix(p,8);
System.out.println();
System.out.println("q:");
PrintfMatrix(q,8);
//輸出分治法矩陣相乘後的結果
int[][] bru_result = BruteForce(p,q,8);//新建一個矩陣存放矩陣相乘結果
System.out.println();
System.out.println("\nA*B(蠻力法):");
PrintfMatrix(bru_result,8);
//輸出分治法矩陣相乘後的結果
int[][] dac_result = DivideAndConquer(p,q,8);//新建一個矩陣存放矩陣相乘結果
System.out.println();
System.out.println("A*B(分治法):");
PrintfMatrix(dac_result,8);
//輸出strassen法矩陣相乘後的結果
int[][] stra_result = Strassen(p,q,8);//新建一個矩陣存放矩陣相乘結果
System.out.println("\nA*B(strassen法):");
PrintfMatrix(stra_result,8);
}
}
運行結果: