矩陣相乘(分治法)

一個簡單的分治算法求矩陣相乘
C=A * B ,假設三個矩陣均爲n×n,n爲2的冪。可以對其分解爲4個n/2×n/2的子矩陣分別遞歸求解:
1
2

遞歸分治算法:
3

算法中一個重要的細節就是在分塊的時候,採用的是下標的方式。

#include <stdio.h>
#include <stdlib.h>
#define ROW 16       //指定 行數
#define COL 16       //指定 列數 

int a[ROW][COL],b[ROW][COL];  //矩陣a 和 矩陣b
int **c;                      // c = a * b 

//保存一個矩陣的第一個元素的位置,即左上角元素的下標
//如果加上一個長度就可以知道整個矩陣了
typedef struct {   //這裏沒有指定一個矩陣的長度,在分塊時應該加入長度,否則不知道子塊矩陣的大小
    int str,stc;    //str行下標  ; strc列下標
}subarr;

// 兩矩陣arr、brr相加減 保存在temp中
void operate(int **arr,int **brr,subarr te,char op,int **temp,int len);

//分治法 求矩陣相乘 ,sa,sb分別爲矩陣a,b參加運算的首元素
int ** square_recursive(subarr sa,subarr sb,subarr sc,int len){
    int n=len;
    int **temp;
    int i;
    // 申請一個臨時矩陣,用於保存a*b 
    temp=(int**)malloc(sizeof(int *)*n);
    for ( i=0;i<n;++i){
        temp[i]=(int *)malloc(sizeof(int)*n);
    }
    // 長度爲1 則直接相乘
    if (n==1)
    {
        temp[0][0]=a[sa.str][sa.stc]*b[sb.str][sb.stc];
    }else{
         // 這裏都是對下標進行初始化
         // sa,sb,sc代表輸入矩陣A,B,temp參加運算的首元素下標,因爲進行分塊後只進行特定子塊的運算
         //標號1,2,3,4 分別代表第一、二、三、四個子塊
        subarr sa1,sb1, sc1;
        subarr sa2,sb2, sc2;
        subarr sa3, sb3,sc3;
        subarr sa4, sb4, sc4;
        // 矩陣A 進行分塊後的各個子塊下標
        sa1.str=sa.str;
        sa1.stc=sa.stc;
        sa2.str=sa.str;
        sa2.stc=sa.stc+n/2;
        sa3.stc=sa.stc;
        sa3.str=sa.str+n/2;
        sa4.str=sa.str+n/2;
        sa4.stc=sa.stc+n/2;
        // 矩陣B 進行分塊後的各個子塊下標
        sb1.str=sb.str;
        sb1.stc=sb.stc;
        sb2.str=sb.str;
        sb2.stc=sb.stc+n/2;
        sb3.stc=sb.stc;
        sb3.str=sb.str+n/2;
        sb4.str=sb.str+n/2;
        sb4.stc=sb.stc+n/2;
        // 矩陣temp 進行分塊後的各個子塊下標
        sc1.str=sc1.stc=0;
        sc2.str=0;
        sc2.stc=n/2;
        sc3.stc=0;
        sc3.str=n/2;
        sc4.str=n/2;
        sc4.stc=n/2;
// 將矩陣分爲四塊  分別求解。採用下標的方式進行分塊,可以省去複製矩陣所產生的時間
// 若要複製矩陣則會產生 O(n*n)的時間複雜度
    operate(square_recursive(sa1,sb1,sc1,n/2),square_recursive(sa2,sb3,sc1,n/2),sc1,'+',temp,n/2);

        operate(square_recursive(sa1,sb2,sc2,n/2),square_recursive(sa2,sb4,sc2,n/2),sc2,'+',temp,n/2);

        operate(square_recursive(sa3,sb1,sc3,n/2),square_recursive(sa4,sb3,sc3,n/2),sc3,'+',temp,n/2);

        operate(square_recursive(sa3,sb2,sc4,n/2),square_recursive(sa4,sb4,sc4,n/2),sc4,'+',temp,n/2);


    }
    return temp;

}
//  temp矩陣的te位置(四個子塊中的一個)=arr+brr
// len表示arr,brr參加運算的長度
// op是運算符 ‘+’ 
void operate(int **arr,int **brr,subarr te,char op,int **temp,int len){
    int i,j;
    switch(op){
        case '+':
            for (i=0;i<len;++i){
                for (j = 0; j < len; ++j)
                {
                    temp[te.str+i][te.stc+j]=arr[i][j]+brr[i][j];
                }
            }
            break;
        case '-':
            for (i=0;i<len;++i){
                for (j = 0; j < len; ++j)
                {
                    temp[te.str+i][te.stc+j]=arr[i][j]-brr[i][j];
                }
            }
            break;
    }
}
//爲矩陣初始化 即賦值
void createarr(int temp[][COL]){
    int i,j;
    for (i = 0; i < ROW; ++i)
    {
        for (j = 0; j < COL; ++j)
        {
            temp[i][j]=(int)rand()%5;

        }

    }

}
// 打印C矩陣
void print(){
    int i,j;
    printf("\n====================================\n");
    for (i = 0; i < ROW; ++i)
    {
        for (j = 0; j < COL; ++j)
        {
            printf("%d\t", c[i][j]);
        }
        printf("\n");
    }
    printf("===================================\n");
}
// 打印矩陣
void printarray(int a[ROW][COL]){
    int i,j;
    printf("-----------------------\n");
    for (i = 0; i < ROW; ++i)
    {
        for (j = 0; j < COL; ++j)
        {
            printf("%d \t", a[i][j]);
        }
        printf("\n");
    }
    printf("-----------------------\n");
}


int main(){
    int i,j;
    subarr sa,sb,sc;
    int len;
    //初始化各個下標
    sa.str=sa.stc=0;
    sb.str=sb.stc=0;
    sc.str=sc.stc=0;
    // 長度賦值,因爲在subarr結構裏沒有長度的定義
    len=ROW;
    //申請空間
    c=(int**)malloc(sizeof(int *)*len);
    for (i=0;i<len;++i){
        c[i]=(int *)malloc(sizeof(int)*len);
    }
    // 給矩陣A,B 複製初始化
    createarr(a);
    createarr(b);
    //  進行運算
    c=square_recursive(sa,sb,sc,len);
    // 打印矩陣A,B,C
    printarray(a);
    printarray(b);
    print();
    return 0;
}

=========== 王傑 原創作品轉載請註明出處==============

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章