《算法導論》學習心得(二)—— 矩陣乘法之Strassen算法

個人blog遷移到www.forwell.me           

    在開始之前,請點擊下載源碼。提起矩陣乘法,你也許會說不就是三次循環就解決問題了嗎,這有什麼好說的。是啊,三個循環確實是完事了,時間效率是O(n^3),這是我們上第一節線代老師就清清楚楚的告訴我們的,但是他沒有告訴你還有比這更好的矩陣乘法,時間效率爲O(n^{log_2 7}) = O(n^{2.807}),也許你覺得這沒有什麼,就提高了0.2幾,沒啥,但是你想過沒有,當N=100,10000的時候呢,Strassen算法和傳統方法又有多少差別呢,讓我們來看一下Strassen算法和傳統方法的效率對比圖:


通過圖我們會發現Strassen算法在N超過50的時候就開始表現出明顯的優勢,然而現實生產中矩陣都是上百階的,那Strassen算法更是佔有絕對的優勢,所以我們今天就很有必要學習Strassen算法,那下面就開始進入正題。

Strassen算法

1969年,德國的一位數學家Strassen證明O(N^3)的解法並不是矩陣乘法的最優算法,他做了一系列工作使得最終的時間複雜度降低到了O(n^2.80)。那他是怎麼做到的呢?對於矩陣乘法  C =  A × B,通常的做法是將矩陣進行分塊相乘,如下圖所示:

從上圖可以看出這種分塊相乘總共用了8次乘法,要改進算法計算時間的複雜度,必須減少乘法運算次數。按分治法的思想,Strassen提出一種新的方法,用7次乘法完成2階矩陣的乘法,算法如下:

M1 = A11(B12 - B12)
M2 = (A11 + A12)B22
M3 = (A21 + A22)B11
M4 = A22(B21 - B11)
M5 = (A11 + A22)(B11 + B22)
M6 = (A12 - A22)(B21 + B22)
M7 = (A11 - A21)(B11 + B12)
完成了7次乘法,再做如下加法:
C11 = M5 + M4 - M2 + M6
C12 = M1 + M2
C21 = M3 + M4
C22 = M5 + M1 - M3 - M7
全部計算使用了7次乘法和18次加減法,計算時間降低到O(nE2.81)。計算複雜性得到較大改進。

具體代碼實現如下:

//Strassen二階矩陣的乘法
	static int[][] twostrassenMatrixMultiply(int [][]x,int [][]y) //階數爲2的矩陣乘法    
	{   
		int matrixXColumnLength = x[0].length;
		int matrixYRowLength = x.length;//獲取矩陣的行長度
		if(matrixXColumnLength!=matrixYRowLength)
		{
			throw new RuntimeException("matrixXColumnLength!=matrixYRowLength,無法進行乘法計算!");
		}
		int p1,p2,p3,p4,p5,p6,p7;//這些都是按照算法定義進行的
		int [][]result = new int[2][2];
		p1=(y[0][1] - y[1][1]) * x[0][0];   
		p2=y[1][1] * (x[0][0] + x[0][1]);   
		p3=(x[1][0] + x[1][1]) * y[0][0];   
		p4=x[1][1] * (y[1][0] - y[0][0]);   
		p5=(x[0][0] + x[1][1]) * (y[0][0]+y[1][1]);   
		p6=(x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);   
		p7=(x[0][0] - x[1][0]) * (y[0][0]+y[0][1]);   
		result[0][0] = p5 + p4 - p2 + p6;   
		result[0][1] = p1 + p2;   
		result[1][0] = p3 + p4;
		result[1][1] = p5 + p1 - p3 - p7;
		return result;   
	}   
整個計算過程爲:

static int[][] strassenMatrixMultiply(int [][]x,int [][]y) //矩陣乘法方法    
	{   
		if(x.length==2)   
		{   
			return twostrassenMatrixMultiply(x,y);
		}   
		else   
		{   
			int matrixLength = x.length/2;
			int[][] a11,a12,a21,a22;
			a11 = new int[matrixLength][matrixLength];
			a12 = new int[matrixLength][matrixLength];
			a21 = new int[matrixLength][matrixLength];
			a22 = new int[matrixLength][matrixLength];
			int[][] b11,b12,b21,b22;
			b11 = new int[matrixLength][matrixLength];
			b12 = new int[matrixLength][matrixLength];
			b21 = new int[matrixLength][matrixLength];
			b22 = new int[matrixLength][matrixLength];
			int[][] c11,c12,c21,c22,c;   
			c11 = new int[matrixLength][matrixLength];
			c12 = new int[matrixLength][matrixLength];
			c21 = new int[matrixLength][matrixLength];
			c22 = new int[matrixLength][matrixLength];
			c = new int[2*matrixLength][2*matrixLength];
			int[][] m1,m2,m3,m4,m5,m6,m7;
			divide(x,a11,a12,a21,a22); //拆分A、B、C矩陣    
			divide(y,b11,b12,b21,b22);   
			divide(c,c11,c12,c21,c22);
			m1=strassenMatrixMultiply(a11,matrixMinus(b12,b22));   
			m2=strassenMatrixMultiply(matrixPlus(a11,a12),b22);
			m3=strassenMatrixMultiply(matrixPlus(a21,a22),b11);
			m4=strassenMatrixMultiply(a22,matrixMinus(b21,b11));   
			m5=strassenMatrixMultiply(matrixPlus(a11,a22),matrixPlus(b11,b22));   
			m6=strassenMatrixMultiply(matrixMinus(a12,a22),matrixPlus(b21,b22));   
			m7=strassenMatrixMultiply(matrixMinus(a11,a21),matrixPlus(b11,b12));   
			c11=matrixPlus(matrixMinus(matrixPlus(m5,m4),m2),m6);   
			c12=matrixPlus(m1,m2);   
			c21=matrixPlus(m3,m4);   
			c22=matrixMinus(matrixMinus(matrixPlus(m5,m1),m3),m7);
			c=merge(c11,c12,c21,c22); //合併C矩陣    
			return c;   
		}    
	}
上面就是整個算法的實現過程,歡迎大家前來討論[email protected]


完整代碼:

package com.tangbo;

import java.util.Random;
import java.util.Scanner;
/*
 * @Author:唐波
 * Strassen矩陣乘法
 * 2014.10.31
 * 程序對比了傳統方法和Strassen算法計算的結果是否相等
 * 算法來源:1969年,德國的一位數學家Strassen證明O(N^3)的解法並不是矩陣乘法的最優算法,他做了一系列工作使得最終的時間複雜度降低到了O(n^2.80)
 */
public class SquareMatrixMultiply {
	static Random random = new Random();
	static Scanner in;
	public static void main(String[] args) 
	{ 
		int matrixLength=0;
		in = new Scanner(System.in);   
		System.out.print("輸入矩陣的階數: ");   
		matrixLength = in.nextInt();
		if(isEven(matrixLength)==0)
		{
			int [][]x=productMatrix(matrixLength);
			int [][]y=productMatrix(matrixLength);
			System.out.println("x矩陣:");
			printMatrix(x);
			System.out.println("y矩陣:");
			printMatrix(y);
			int [][]strassenResult =strassenMatrixMultiply(x,y);//Strassen計算結果
			System.out.println("Strassen計算結果:");
			printMatrix(strassenResult);
			int [][] forceResult = forceMatrixMultiply(x, y);//傳統方法計算結果
			System.out.println("傳統計算結果:");
			printMatrix(forceResult);
			boolean isEqual = isEqual(forceResult, strassenResult);//比較兩種計算結果是否相等
			if(isEqual)
			{
				System.out.println("兩個計算結果相等!");
			}else
			{
				System.err.println("兩個計算結果不相等!");
				System.exit(0);//程序退出
			}
		}else
		{
			System.out.println("矩陣不是2^k方陣,無法計算!");
		}
	}
	static boolean isEqual(int [][]x,int [][]y)//遍歷判斷兩個矩陣是否相等
	{
		boolean equal=true;
		for(int i =0;i<x.length;i++)
		{
			for(int j=0;j<x[0].length;j++)
			{
				if(x[i][j]!=y[i][j])
				{
					equal=false;
				}
			}
		}
		return equal;
	}
	static int isEven(int n)//判斷輸入矩陣階數是否爲2^k
	{   
		int a = 1,temp=n;   
		while(temp%2==0)   
		{   
			if(temp%2==0)    
				temp/=2; 
		}  
		if(temp==1)    
			a=0;   
		return a;
	}   
	static int[][] productMatrix(int matrixLength)//自動生成矩陣
	{
		int matrix[][] = new int[matrixLength][matrixLength];
		//初始化矩陣
		for(int i=0;i<matrixLength;i++)
		{
			for(int j=0;j<matrixLength;j++)
			{
				matrix[i][j] = (int)(Math.random()*10);
			}
		}
		System.out.println();
		return matrix;
	}
	static void printMatrix(int matrix[][])//矩陣打印函數
	{
		int matrixRowLength = matrix.length;//獲取矩陣的行數
		int matrixColumnLength = matrix[0].length;//獲取矩陣的列數
		for(int i=0;i<matrixRowLength;i++)
		{
			for(int j=0;j<matrixColumnLength;j++)
			{
				System.out.print(matrix[i][j]+" ");
			}
			System.out.println();
		}
	}
	static int[][] matrixPlus(int[][] x,int[][] y) //矩陣加法    
	{   
		int matrixXRowLength = x.length;//獲取矩陣的行長度
		int matrixXColumnLength = x[0].length;
		int matrixYRowLength = x.length;//獲取矩陣的行長度
		int matrixYColumnLength = x[0].length;
		if(matrixXColumnLength!=matrixYColumnLength || matrixXRowLength!=matrixYRowLength)//判斷矩陣是否同型
		{
			throw new RuntimeException("矩陣不同型,無法進行加法計算!");	
		}
		int[][] result = new int[matrixXRowLength][matrixXColumnLength];
		for(int i=0;i<matrixXColumnLength;i++)
		{
			for (int j = 0; j < matrixXColumnLength; j++) 
			{
				result[i][j] = x[i][j]+y[i][j]; 
			}
		}
		return result;
	}   

	static int[][] matrixMinus(int[][] x,int[][] y)//矩陣減法
	{
		int matrixXRowLength = x.length;//獲取矩陣的行長度
		int matrixXColumnLength = x[0].length;
		int matrixYRowLength = x.length;//獲取矩陣的行長度
		int matrixYColumnLength = x[0].length;
		if(matrixXColumnLength!=matrixYColumnLength || matrixXRowLength!=matrixYRowLength)
		{
			throw new RuntimeException("矩陣不同型,無法進行減法計算!");	
		}
		int[][] result = new int[matrixXRowLength][matrixXColumnLength];
		for(int i=0;i<matrixXColumnLength;i++)
		{
			for (int j = 0; j < matrixXColumnLength; j++) 
			{
				result[i][j] = x[i][j]-y[i][j]; 
			}
		}
		return result;
	}

	//Strassen二階矩陣的乘法
	static int[][] twostrassenMatrixMultiply(int [][]x,int [][]y) //階數爲2的矩陣乘法    
	{   
		int matrixXColumnLength = x[0].length;
		int matrixYRowLength = x.length;//獲取矩陣的行長度
		if(matrixXColumnLength!=matrixYRowLength)
		{
			throw new RuntimeException("matrixXColumnLength!=matrixYRowLength,無法進行乘法計算!");
		}
		int p1,p2,p3,p4,p5,p6,p7;//這些都是按照算法定義進行的
		int [][]result = new int[2][2];
		p1=(y[0][1] - y[1][1]) * x[0][0];   
		p2=y[1][1] * (x[0][0] + x[0][1]);   
		p3=(x[1][0] + x[1][1]) * y[0][0];   
		p4=x[1][1] * (y[1][0] - y[0][0]);   
		p5=(x[0][0] + x[1][1]) * (y[0][0]+y[1][1]);   
		p6=(x[0][1] - x[1][1]) * (y[1][0]+y[1][1]);   
		p7=(x[0][0] - x[1][0]) * (y[0][0]+y[0][1]);   
		result[0][0] = p5 + p4 - p2 + p6;   
		result[0][1] = p1 + p2;   
		result[1][0] = p3 + p4;
		result[1][1] = p5 + p1 - p3 - p7;
		return result;   
	}   
	static void divide(int[][] a,int[][] a11,int[][] a12,int[][] a21,int[][] a22)//分解矩陣
	{   
		int matrixLength = a.length/2;
		for(int i=0;i<matrixLength;i++)   
			for(int j=0;j<matrixLength;j++)   
			{
				a11[i][j]=a[i][j];
				a12[i][j]=a[i][j+matrixLength];   
				a21[i][j]=a[i+matrixLength][j];   
				a22[i][j]=a[i+matrixLength][j+matrixLength];   
			}   
	}

	static int[][] merge(int [][]a11,int [][]a12,int [][]a21,int [][]a22)//合併矩陣    
	{   
		int n=a11.length;
		int [][] result = new int[2*n][2*n];
		for(int i=0;i<n;i++)
		{
			for(int j=0;j<n;j++)
			{
				result[i][j]=a11[i][j];   
				result[i][j+n]=a12[i][j];   
				result[i+n][j]=a21[i][j];   
				result[i+n][j+n]=a22[i][j];   
			}
		}
		return result;
	}
	static int[][] strassenMatrixMultiply(int [][]x,int [][]y) //矩陣乘法方法    
	{   
		if(x.length==2)   
		{   
			return twostrassenMatrixMultiply(x,y);
		}   
		else   
		{   
			int matrixLength = x.length/2;
			int[][] a11,a12,a21,a22;
			a11 = new int[matrixLength][matrixLength];
			a12 = new int[matrixLength][matrixLength];
			a21 = new int[matrixLength][matrixLength];
			a22 = new int[matrixLength][matrixLength];
			int[][] b11,b12,b21,b22;
			b11 = new int[matrixLength][matrixLength];
			b12 = new int[matrixLength][matrixLength];
			b21 = new int[matrixLength][matrixLength];
			b22 = new int[matrixLength][matrixLength];
			int[][] c11,c12,c21,c22,c;   
			c11 = new int[matrixLength][matrixLength];
			c12 = new int[matrixLength][matrixLength];
			c21 = new int[matrixLength][matrixLength];
			c22 = new int[matrixLength][matrixLength];
			c = new int[2*matrixLength][2*matrixLength];
			int[][] m1,m2,m3,m4,m5,m6,m7;
			divide(x,a11,a12,a21,a22); //拆分A、B、C矩陣    
			divide(y,b11,b12,b21,b22);   
			divide(c,c11,c12,c21,c22);
			m1=strassenMatrixMultiply(a11,matrixMinus(b12,b22));   
			m2=strassenMatrixMultiply(matrixPlus(a11,a12),b22);
			m3=strassenMatrixMultiply(matrixPlus(a21,a22),b11);
			m4=strassenMatrixMultiply(a22,matrixMinus(b21,b11));   
			m5=strassenMatrixMultiply(matrixPlus(a11,a22),matrixPlus(b11,b22));   
			m6=strassenMatrixMultiply(matrixMinus(a12,a22),matrixPlus(b21,b22));   
			m7=strassenMatrixMultiply(matrixMinus(a11,a21),matrixPlus(b11,b12));   
			c11=matrixPlus(matrixMinus(matrixPlus(m5,m4),m2),m6);   
			c12=matrixPlus(m1,m2);   
			c21=matrixPlus(m3,m4);   
			c22=matrixMinus(matrixMinus(matrixPlus(m5,m1),m3),m7);
			c=merge(c11,c12,c21,c22); //合併C矩陣    
			return c;   
		}    
	}
	static int[][] forceMatrixMultiply(int [][]x,int [][]y)
	{
		int matrixXRowLength = x.length;//獲取矩陣的行長度
		int matrixXColumnLength = x[0].length;
		int matrixYRowLength = x.length;//獲取矩陣的行長度
		int matrixYColumnLength = x[0].length;
		if(matrixXColumnLength!=matrixYRowLength)
		{
			throw new RuntimeException("matrixXColumnLength!=matrixYRowLength,無法進行乘法計算!");
		}
		int [][] result = new int[matrixXRowLength][matrixYColumnLength];
		for(int i=0;i<matrixXRowLength;i++)
		{
			for(int j=0;j<matrixYColumnLength;j++)
			{
				result[i][j]=0;
				for(int k=0;k<matrixYRowLength;k++)
				{
					result[i][j] = result[i][j]+x[i][k]*y[k][j];
				}
			}
		}
		return result;
	}

}





發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章