算法題03:分治與遞歸:三種矩陣乘法(蠻力法,樸素分治法、Strassen法)

問題:
AABB 是兩個 n×nn\times n 階矩陣,求它們的乘積矩陣C。這裏,假設 nn22 的冪次方。

一、問題分析(模型、算法設計和正確性證明等)

​ 實驗要求使用分治法解決n階矩陣(n是2的冪次方)相乘問題,因爲n是2的冪次方,可以使用樸素分塊矩陣乘法或者 Strassen 法,這裏兩種都嘗試一下,順便連蠻力法也放進去。

二、複雜度分析

蠻力法僞代碼

for i = 1 to n do:
	for j = 1 to n do:
		for k = 1 to n do:
			C[i][j] = C[i][j] + A[i][k]・B[k][j]

顯然時間複雜度爲O(n3)O(n^3).

樸素分塊矩陣乘法僞代碼

Divide_And_Conquer(int[][]A,int[][]B,int n){

    int [][]C = new int[n][n];		//定義一個新矩陣存放結果
    if n==1:
        C11=A11*B11;
    else Divide A, B and C as in 4 equation:
    	n /= 2;
        C11=Divide_And_Conquer(A11,B11,n) + Divide_And_Conquer(A12,B21,n);
        C22=Divide_And_Conquer(A11,B12,n) + Divide_And_Conquer(A12,B22,n);
        C21=Divide_And_Conquer(A21,B11,n) + Divide_And_Conquer(A22,B21,n);
        C22=Divide_And_Conquer(A21,B22,n) + Divide_And_Conquer(A22,B22,n);

    return C;
}

因爲n/2 * n/2 的矩陣乘法進行了8次, n/2 * n/2的矩陣加法進行了4次所以複雜度爲:

T(n)={1n=18T(n/2)+4(n2)2n>1O(n3)T(n)=\left\{\begin{array}{ll} 1 & n=1 \\ 8 T(n / 2)+4\left(\frac{n}{2}\right)^{2} & n>1 \end{array} \Longrightarrow O\left(n^{3}\right)\right.

Strassen 法僞代碼:

Strassen_DAC(int [][]A, int [][]B,int n){
    int [][]C = new int[n][n];		//定義結果矩陣
    if n==1:
    	C11 = A11*B11;
    else Divide A, B, and C as in 4 equation:
    	n /= 2;
    	int [][]M1,M2,M3,M4,M5,M6,M7 = new int[n][n];
    	M1 = Strassen_DAC(A11, B12-B22, n);
		M2 = Strassen_DAC(A11+A12, B22, n);
    	M3 = Strassen_DAC(A21+A22, B11, n);
    	M4 = Strassen_DAC(A22, B21-B11, n);
    	M5 = Strassen_DAC(A11+A12, B11+B12, n);
    	M6 = Strassen_DAC(A12-A22, B21+B22, n);
    	M7 = Strassen_DAC(A11-A21, B11+B12, n);
    	C11 = M5 + M4 - M2 + M6;
    	C12 = M1 + M2;
    	C21 = M3 + M4;
    	C22 = M5 + M1 -M3 -M7;
    return C;
}

從僞代碼明顯可以看出,程序執行了7次n/2 * n/2的矩陣乘法,以及 18次n/2 *n/2的矩陣加減運算,所以複雜度爲:

T(n)={1n=17T(n/2)+18(n2)2n>1O(nlog3)T(n)=\left\{\begin{array}{ll}1 & n=1 \\ 7 T(n / 2)+18\left(\frac{n}{2}\right)^{2} & n>1\end{array} \Longrightarrow O\left(n^{\log 3}\right)\right.

三、程序實現和測試過程和結果(主要描述出現的問題和解決方法)

​ 算法的思路倒是不難,難的是具體實現的時候,矩陣的分塊操作,容易繞不清楚。而且其中還有矩陣的加減運算,得出的結果最後還要把C11,C12,C21,C22C11, C12, C21, C22 合併成爲一個矩陣CC,這些在僞代碼裏都沒有給出來,但是複雜繁瑣容易出bug的正是這些細節。

​ 實驗中使用Java語言編寫將三種方法放到同一個類中,一下爲類中的各個方法:

在這裏插入圖片描述

源碼:

package root;

/**
 * @author 宇智波Akali
 * 這是三種矩陣乘法
 * @date 2020.3.18
 */
public class Try {
	//創建一個隨機數構成的n*n矩陣
	public static int[][] initializationMatrix(int n){
		int[][] result = new int[n][n];//創建一個n*n矩陣
		for(int i = 0;i < n;i++){
			for(int j = 0;j < n;j++){
				result[i][j] = (int)(Math.random()*10); //隨機生成1~10之間的數
			}
		} 
		return result; 
	}

	//蠻力法求矩陣相乘
	public static int[][] BruteForce(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++){
				result[i][j] = 0;
				for(int k=0;k<n;k++){
					result[i][j] += p[i][k]*q[k][j];
				}
			}
		}  
		return result;
	}

	//分治法求矩陣相乘
	public static int[][] DivideAndConquer(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];//創建一個n*n矩陣
		//當n爲2時,用蠻力法求矩陣相乘,返回結果結果
		if(n == 2){
			result = BruteForce(p,q,n); 
			return result;
		}
	 
		//當n大於3時,採用分治法,遞歸求最終結果
		if(n > 2){
			int m = n/2;
			
			//將矩陣p分成四塊
			int[][] p1 = QuarterMatrix(p,n,1);
			int[][] p2 = QuarterMatrix(p,n,2);
			int[][] p3 = QuarterMatrix(p,n,3);
			int[][] p4 = QuarterMatrix(p,n,4);
			
			//將矩陣q分成四塊
			int[][] q1 = QuarterMatrix(q,n,1);
			int[][] q2 = QuarterMatrix(q,n,2);
			int[][] q3 = QuarterMatrix(q,n,3);
			int[][] q4 = QuarterMatrix(q,n,4);
			
			//將結果矩陣分成同等大小的四塊
			int[][] result1 = QuarterMatrix(result,n,1);
			int[][] result2 = QuarterMatrix(result,n,2);
			int[][] result3 = QuarterMatrix(result,n,3);
			int[][] result4 = QuarterMatrix(result,n,4);
		
			//最關鍵的步驟,遞歸調用DivideAndConquer()函數,並用公式相加
			result1 = AddMatrix(DivideAndConquer(p1,q1,m),DivideAndConquer(p2,q3,m),m);//y=ae+bg
			result2 = AddMatrix(DivideAndConquer(p1,q2,m),DivideAndConquer(p2,q4,m),m);//s=af+bh
			result3 = AddMatrix(DivideAndConquer(p3,q1,m),DivideAndConquer(p4,q3,m),m);//t=ce+dg
			result4 = AddMatrix(DivideAndConquer(p3,q2,m),DivideAndConquer(p4,q4,m),m);//u=cf+dh
			
			//合併,將四塊小矩陣合成整體
			result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四個小矩陣合併成一個大矩陣
		}
		return result;
	}
	
	//strassen法
	public static int[][] Strassen(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];//創建一個n*n矩陣
		if( n == 2){
			result = BruteForce(p,q,n);
			return result;
		}
		int m = n/2;
		
		//將矩陣p分成四塊
		int[][] p1 = QuarterMatrix(p,n,1);
		int[][] p2 = QuarterMatrix(p,n,2);
		int[][] p3 = QuarterMatrix(p,n,3);
		int[][] p4 = QuarterMatrix(p,n,4);
		
		//將矩陣q分成四塊
		int[][] q1 = QuarterMatrix(q,n,1);
		int[][] q2 = QuarterMatrix(q,n,2);
		int[][] q3 = QuarterMatrix(q,n,3);
		int[][] q4 = QuarterMatrix(q,n,4);
				int[][] m1 = DivideAndConquer(AddMatrix(p1,p4,m),AddMatrix(q1,q4,m),m);
		int[][] m2 = Strassen(AddMatrix(p3,p4,m),q1,m);
		int[][] m3 = Strassen(p1,ReduceMatrix(q2,q4,m),m);
		int[][] m4 = Strassen(p4,ReduceMatrix(q3,q1,m),m);
		int[][] m5 = Strassen(AddMatrix(p1,p2,m),q4,m);
		int[][] m6 = Strassen(ReduceMatrix(p3,p1,m),AddMatrix(q1,q2,m),m);
		int[][] m7 = Strassen(ReduceMatrix(p2,p4,m),AddMatrix(q3,q4,m),m);
		
		//將結果矩陣分成同等大小的四塊
		int[][] result1 = QuarterMatrix(result,n,1);
		int[][] result2 = QuarterMatrix(result,n,2);
		int[][] result3 = QuarterMatrix(result,n,3);
		int[][] result4 = QuarterMatrix(result,n,4);
	
		result1 = AddMatrix(ReduceMatrix(AddMatrix(m1,m4,m),m5,m),m7,m);
		result2 = AddMatrix(m3,m5,m);
		result3 = AddMatrix(m2,m4,m);
		result4 = AddMatrix(AddMatrix(ReduceMatrix(m1,m2,m),m3,m),m6,m);
		
		result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四個小矩陣合併成一個大矩陣
		
		return result;
	}
	
	
	
	
	//獲取矩陣的四分之一,number用來確定返回哪一個四分之一
	public static int[][] QuarterMatrix(int[][] p,int n,int number){
		int rows = n/2;  //行數減半
		int cols = n/2;  //列數減半
		int[][] result = new int[rows][cols];
		switch(number){
		//左上
		case 1 :
		{
			for(int i=0;i<rows;i++)
				for(int j=0;j<cols;j++)
					result[i][j] = p[i][j];
			break;
		}
		//右上
		case 2 :
		{
			for(int i=0;i<rows;i++)
				for(int j=0;j<n-cols;j++)
					result[i][j] = p[i][j+cols];
			break;
		}
		//左下
		case 3 :
		{
			for(int i=0;i<n-rows;i++)
				for(int j=0;j<cols;j++)
					result[i][j] = p[i+rows][j];
			break;
		}
		//右下
		case 4 :
		{
			for(int i=0;i<n-rows;i++)
				for(int j=0;j<n-cols;j++)
					result[i][j] = p[i+rows][j+cols];
			break;
		}
		default:
			break;
		}
	
		return result;
	}

	//把均分爲四分之一的矩陣,合成一個矩陣
	public static int[][] TogetherMatrix(int[][] a,int[][] b,int[][] c,int[][] d,int n){
		int[][] result = new int[2*n][2*n];
		for(int i=0;i<2*n;i++){
			for(int j=0;j<2*n;j++){
				if(i<n){
					if(j<n)
						result[i][j] = a[i][j];
					else
						result[i][j] = b[i][j-n];
				}else{
					if(j<n)
						result[i][j] = c[i-n][j];
					else
						result[i][j] = d[i-n][j-n];
				}
			}
		}
		return result;
	}


	//求兩個矩陣相加結果
	public static int[][] AddMatrix(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++){
				result[i][j] = p[i][j]+q[i][j];
			}
		}
		return result;
	}
	
	//求兩個矩陣相減結果
	public static int[][] ReduceMatrix(int[][] p,int[][] q,int n){
		int[][] result = new int[n][n];
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++){
				result[i][j] = p[i][j]-q[i][j];
			}
		}
		return result;
	}
	
	//輸出矩陣的函數
	public static void PrintfMatrix(int[][] matrix,int n){
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++)
				System.out.printf("% 4d",matrix[i][j]);
			System.out.println();
		}
	
	}

	public static void main(String args[]){
		int[][] p = initializationMatrix(8);
		int[][] q = initializationMatrix(8);
		//輸出生成的兩個矩陣
		System.out.println("p:");
		PrintfMatrix(p,8);
		System.out.println();
		System.out.println("q:");
		PrintfMatrix(q,8);
 
		//輸出分治法矩陣相乘後的結果
		int[][] bru_result = BruteForce(p,q,8);//新建一個矩陣存放矩陣相乘結果
		System.out.println();
		System.out.println("\nA*B(蠻力法):");
		PrintfMatrix(bru_result,8);
		
		//輸出分治法矩陣相乘後的結果
		int[][] dac_result = DivideAndConquer(p,q,8);//新建一個矩陣存放矩陣相乘結果
		System.out.println();
		System.out.println("A*B(分治法):");
		PrintfMatrix(dac_result,8);
		
		//輸出strassen法矩陣相乘後的結果
		int[][] stra_result = Strassen(p,q,8);//新建一個矩陣存放矩陣相乘結果
		System.out.println("\nA*B(strassen法):");
		PrintfMatrix(stra_result,8);
		
	}
 
}

運行結果:
運行結果

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章