【darknet源碼解析-06】gemm.h和gemm.c解析

本系列爲darknet源碼解析,本次解析src/gemm.h 與src/gemm.c兩個。在上一篇文章中,我們已經詳細講解了輸入特徵圖如何進行轉換,那麼在本文中,gemm主要完成矢量和矩陣的加速運算,是darknet卷積底層實現的核心,其實也是caffe卷積實現的核心。

gemm.h 的包含的代碼如下:主要就是兩個函數的gemm,gemm_cpu的定義【gemm_bin暫不分析】,在這裏我們先不涉及gpu那塊,先講解cpu這塊的矩陣加速運算。

#ifndef GEMM_H
#define GEMM_H

void gemm_bin(int M, int N, int K, float ALPHA, 
        char  *A, int lda, 
        float *B, int ldb,
        float *C, int ldc);
        
void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
                    float *A, int lda, 
                    float *B, int ldb,
                    float BETA,
                    float *C, int ldc);

void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc);

#ifdef GPU
void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A_gpu, int lda, 
        float *B_gpu, int ldb,
        float BETA,
        float *C_gpu, int ldc);

void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc);
#endif
#endif

gemm.c 的詳細分析如下,可以先看後面的白話總結描述以及小例子再來看源碼:

#include "gemm.h"
#include "utils.h"
#include "cuda.h"
#include <stdlib.h>
#include <stdio.h>
#include <math.h>

void gemm_bin(int M, int N, int K, float ALPHA, 
        char  *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            char A_PART = A[i*lda+k];
            if(A_PART){
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] += B[k*ldb+j];
                }
            } else {
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] -= B[k*ldb+j];
                }
            }
        }
    }
}

float *random_matrix(int rows, int cols)
{
    int i;
    float *m = calloc(rows*cols, sizeof(float));
    for(i = 0; i < rows*cols; ++i){
        m[i] = (float)rand()/RAND_MAX;
    }
    return m;
}

void time_random_matrix(int TA, int TB, int m, int k, int n)
{
    float *a;
    if(!TA) a = random_matrix(m,k);
    else a = random_matrix(k,m);
    int lda = (!TA)?k:m;
    float *b;
    if(!TB) b = random_matrix(k,n);
    else b = random_matrix(n,k);
    int ldb = (!TB)?n:k;

    float *c = random_matrix(m,n);
    int i;
    clock_t start = clock(), end;
    for(i = 0; i<10; ++i){
        gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
    }
    end = clock();
    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
    free(a);
    free(b);
    free(c);
}

/**
 * gemm函數調用了gemm_cpu()函數,並且將參數原封不動的傳給gemm_cpu()
 */
void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    gemm_cpu( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
}


/**
 * 被gemm_cpu函數調用,實際完成 C = ALPHA * A * B + C 矩陣運算,輸出的C也是按行存儲(所有行併成一行)
 * @param M A,C的行數(不做轉置)
 * @param N B,C的列數(不做裝置)
 * @param K A的列數,C的行數(不做轉置)
 * @param ALPHA 係數
 * @param A 輸入矩陣(一維數組格式)
 * @param lda A的列數(不做轉置)
 * @param B 輸入矩陣(一維數組格式)
 * @param ldb B的列數(不做轉置)
 * @param C 輸入矩陣(一維數組格式)
 * @param ldc C的列數(不做轉置)
 *
 * 說明:此函數在gemm_cpu()函數中調用,是其中四中情況之一,A不進行轉置,B不進行轉置
 *      函數名gemm_nt()中nt分別表示 not transpose, tranpose
 */
void gemm_nn(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{ // input: 矩陣A[M,K], filter: 矩陣B[K,N],  output: 矩陣C[M,N]
    int i,j,k;
    #pragma omp parallel for
    // 大循環:遍歷A的每一行,i表示A的第i行,也是C的第i行
    for(i = 0; i < M; ++i){
        // 中循環:遍歷每一行的所有列,k表示A的第k列,同時表示B的第k行
        for(k = 0; k < K; ++k){
                            // 先計算ALPHA * A (A中每一個元素乘以ALPHA)
            register float A_PART = ALPHA*A[i*lda+k];
            // 內循環:遍歷B中所有列,每次大循環完畢,將計算得到A×B一行的結果
            // j是B的第j列,也是C的第j列
            for(j = 0; j < N; ++j){
                // A中第i行k列與B中第k行i列對應相乘,因爲一個大循環要計算A×B一行的結果
                // 因此,這裏用一個內循環,並沒有直接乘以B[k*ldb+i]
                // 每個內循環完畢,將計算A×B整行的部分結果(A中第i行k列與B所有列第k行所有元素相乘的結果)
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}


/**
 * 被gemm_cpu()函數調用,實際完成 C = ALPHA * A * B^T + C 矩陣計算
 * @param M A,C的行數(不做轉置)或者A^T的行數(做轉置),此處A未轉置,故爲A的行數
 * @param N B,C的列數(不做轉置)或者B^T的列數(做轉置),此處B轉置,故爲B^T的列數
 * @param K  A的列數(不做轉置)或者A^T的列數(做轉置),B的行數(不做轉置)或者B^T(做轉置),此處A未轉置,B轉置,故爲A的列數,B^T的行數
 * @param ALPHA 係數
 * @param A 輸入矩陣
 * @param lda  A的列數(不做轉置)或者A^T的行數(做轉置),此處A未轉置,故爲A的列數
 * @param B 輸入矩陣
 * @param ldb B的列數(不做轉置)或者B^T的行數(做轉置),此處B轉置,故爲B^T的行數
 * @param C 輸入矩陣
 * @param ldc 矩陣C的列數
 * 說明:此函數在gemm_cpu()函數中調用,是其中四中情況之一,A不進行轉置,B轉置
 *      函數名gemm_nt()中nt分別表示 not transpose, tranpose
 */
void gemm_nt(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{// input: 矩陣A[M,K], filter: 矩陣B[K,N],  output: 矩陣C[M,N]
    int i,j,k;
    #pragma omp parallel for
    // 大循環:遍歷A的每一行,i表示A的第i行,也是C的第i行
    for(i = 0; i < M; ++i){

        for(j = 0; j < N; ++j){
            register float sum = 0;
            //內循環:每次內循環結束,將計算A中第i行與B中第j列相乘的結果
            //也就是得到C[i][j],因爲C也一維化,且按行存儲,所以得到C[i*lda+j]
            // k表示A的第幾列,也表示
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i*lda+k]*B[j*ldb + k];
            }
            C[i*ldc+j] += sum;
        }
    }
}

/**
 * 被gemm_cpu()函數調用,實際完成 C = ALPHA * A^T * B + C 矩陣計算
 * @param M A,C的行數(不做轉置)或者A^T的行數(做轉置),此處A轉置,故爲A^T的行數
 * @param N B,C的列數(不做轉置)或者B^T的列數(做轉置),此處B未轉置,故爲B的列數
 * @param K  A的列數(不做轉置)或者A^T的列數(做轉置),B的行數(不做轉置)或者B^T行數(做轉置),此處A未轉置,B轉置,故爲A^T的列數,B的行數
 * @param ALPHA 係數
 * @param A 輸入矩陣
 * @param lda  A的列數(不做轉置)或者A^T的行數(做轉置),此處A轉置,故爲A^T的行數
 * @param B 輸入矩陣
 * @param ldb B的列數(不做轉置)或者B^T的行數(做轉置),此處B未轉置,故爲B的列數
 * @param C 輸入矩陣
 * @param ldc 矩陣C的列數
 * 說明:此函數在gemm_cpu()函數中調用,是其中四中情況之一,A進行轉置,B不進行轉置
 *      函數名gemm_tn()中tn分別表示  tranpose,not transpose
 */
void gemm_tn(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            register float A_PART = ALPHA*A[k*lda+i];
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}


/**
 * 被gemm_cpu()函數調用,實際完成 C = ALPHA * A^T  * B^T + C 矩陣計算
 * @param M A,C的行數(不做轉置)或者A^T的行數(做轉置),此處A轉置,故爲A^T的行數
 * @param N B,C的列數(不做轉置)或者B^T的列數(做轉置),此處B轉置,故爲B^T的列數
 * @param K  A的列數(不做轉置)或者A^T的列數(做轉置),B的行數(不做轉置)或者B^T(做轉置),此處A轉置,B轉置,故爲A^T的列數,B^T的行數
 * @param ALPHA 係數
 * @param A 輸入矩陣
 * @param lda  A的列數(不做轉置)或者A^T的行數(做轉置),此處A轉置,故爲A^T的行數
 * @param B 輸入矩陣
 * @param ldb B的列數(不做轉置)或者B^T的行數(做轉置),此處B轉置,故爲B^T的行數
 * @param C 輸入矩陣
 * @param ldc 矩陣C的列數
 * 說明:此函數在gemm_cpu()函數中調用,是其中四中情況之一,A進行轉置,B進行轉置
 *      函數名gemm_tt()中tt分別表示 transpose, tranpose
 */
void gemm_tt(int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    #pragma omp parallel for
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            register float sum = 0;
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
            }
            C[i*ldc+j] += sum;
        }
    }
}



/**
 * 矩陣計算,完成C = ALPHA * A * B + BETA * C 矩陣計算,最後的輸出爲C
 * @param TA 是否需要對A做轉置操作,是爲1,否爲0(要不要轉置取決於A,B之間的維度是否匹配,比如A:3*2, B:4*2, 則需要對B轉置,才滿足矩陣乘法)
 * @param TB 同上
 * @param M A,C 的行數(若A需要轉置,則此出給出轉置後A即A^T的行數,而不是轉置前的)
 * @param N B,C 的列數(若B需要轉置,則此處給出轉置後B即B^T的列數,而不是轉置前的)
 * @param K A的列數,B的行數(同樣,若A與B中的二者或者其中一個需要轉置,則不管怎麼樣,轉置後的A,B必須行列能夠匹配,符合矩陣乘法規則,K也是轉置後的值,不是轉置的)
 * @param ALPHA 係數
 * @param A 輸入矩陣
 * @param lda A的列數(不做轉置)或者行數(做轉置,且給的是轉置後A即A^T的行數)
 * @param B 輸入矩陣
 * @param ldb B的列數(不做轉置)或者行數(做轉置,且給的是轉置後B即B^T的行數)
 * @param BETA 係數
 * @param C 輸入矩陣
 * @param ldc C的列數
 */
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A, int lda, 
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    //printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
    int i, j;
    // 先行計算BETA * C,並把結果存入C中,得到C將爲M行N列(按行存儲在一維數組中)
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            C[i*ldc + j] *= BETA;
        }
    }
    if(!TA && !TB) // TA = 0, TB = 0,
        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(TA && !TB) // TA = 1, TB = 0
        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(!TA && TB) // TA = 0, TB = 1
        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else // TA = 1, TB = 1
        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
}



#ifdef GPU

#include <math.h>


/**
 * 矩陣計算GPU實現,調用CUDA中cublasSgemm()函數完成 C_gpu = ALPHA + A_gpu * B_gpu + BETA * C_gpu的線性矩陣運算,
 * 與gemm_cpu()基本類似,輸入參數也基本相同,但是存在兩點不同:
 * 1. 此處是直接調用CUDA cuBLAS庫中的cublasSgemm()函數進行矩陣運算,而無需gemm_cpu()那樣,需要自己用循環挨個元素相乘實現;
 * 2. 在GPU中,默認採用的矩陣存儲格式是按列存儲,而不是我們之前一度習慣的按行存儲,此處調用的cublasSgemm()也不例外,
 *    所以下面會有一些不同的操作(由於這個原因,相比於cpu版本的gemm_cpu(),又要複雜一些)
 *
 *
 *
 *
 * GPU使用cuBLAS庫中cublasSgemm()函數進行矩陣乘法計算,參看:
 * 這個網址是CUDA關於cuBLAS庫的官方文檔,此處cublasSgemm()函數在2.7.1節: cublas<t>gemm();
 * 可以看出cublasSgem()函數完成C_gpu = ALPHA * A_gpu * B_gpu + BETA * C_gpu的線性矩陣計算
 *
 * @param TA 是否需要對A做轉置操作,是爲1,否爲0(要不要轉置取決於A,B之間的維度是否匹配,比如A:3*2, B:4*2, 則需要對B轉置,才滿足矩陣乘法)
 * @param TB 同上
 * @param M A,C 的行數(若A需要轉置,則此出給出轉置後A即A^T的行數,而不是轉置前的)
 * @param N B,C 的列數(若B需要轉置,則此處給出轉置後B即B^T的列數,而不是轉置前的)
 * @param K A的列數,B的行數(同樣,若A與B中的二者或者其中一個需要轉置,則不管怎麼樣,轉置後的A,B必須行列能夠匹配,符合矩陣乘法規則,K也是轉置後的值,不是轉置的)
 * @param ALPHA 係數
 * @param A_gpu 輸入矩陣,且其內存在GPU設備內存中,不在主機內存中(由cudaMalloc分配,由cudaFree釋放)
 * @param lda A的列數(不做轉置)或者行數(做轉置,且給的是轉置後A即A^T的行數)
 * @param B_gpu 輸入矩陣,且其內存在GPU設備內存中,不在主機內存中(由cudaMalloc分配,由cudaFree釋放)
 * @param ldb B的列數(不做轉置)或者行數(做轉置,且給的是轉置後B即B^T的行數)
 * @param BETA 係數
 * @param C_gpu 輸入矩陣,且其內存在GPU設備內存中,不在主機內存中(由cudaMalloc分配,由cudaFree釋放)
 * @param ldc C的列數
 *
 * 可以看出,如果不是因爲存儲方式的不同,cublasSgemm()函數的結構也與darknet自己實現的cpu版本的gemm_cpu一模一樣;
 * 因爲二者存儲格式不同,需要交換A_gpu, B_gpu的位置,對應M和N之間,TB與TA之間,ldb與lda之間都要相互交換;
 *
 */
void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA, 
        float *A_gpu, int lda, 
        float *B_gpu, int ldb,
        float BETA,
        float *C_gpu, int ldc)
{
    //根據官網,這個變量是一個對開發者不透明的變量,也就是裏面聚義包給什麼,開發這一般無法知道,
    //只知道里麪包含的cuBLAS庫的相關信息,且這個變量是必須的,按照官網的描述,CUBLAS庫中所有的函數都需要這個變量參數
    //(且都是作爲第一個參數),該變量由cublasCreate()初始化,並由cuBLASDestroy()銷燬。
    cublasHandle_t handle = blas_handle();

    /* cublasSgemm()函數輸入參數說明
     * @param handle
     * @param transa 是否需要轉置A_gpu, 這裏transa = TB ? CUBLAS_OP_T : CUBLAS_OP_N (是個條件表達式),如果TB =1,
     *               則取CUBLAS_OP_T,即需要對A_gpu轉置;
     * @param transb 是否需要轉置A_gpu, 這裏transa = TA ? CUBLAS_OP_T : CUBLAS_OP_N (是個條件表達式),如果TA =1,
     *               則取CUBLAS_OP_T,即需要對B_gpu轉置;
     * @param M A_gpu,C_gpu 的行數(若A_gpu需要轉置,則此出給出轉置後A_gpu即A_gpu^T的行數,而不是轉置前的)
     * @param N B_gpu,C_gpu 的列數(若B_gpu需要轉置,則此處給出轉置後B_gpu即B_gpu^T的列數,而不是轉置前的)
     * @param K A_gpu的列數,B_gpu的行數(同樣,若A_gpu與B_gpu中的二者或者其中一個需要轉置,則不管怎麼樣,轉置後的A_gpu,B_gpu必須
     *          行列能夠匹配,符合矩陣乘法規則,K也是轉置後的值,不是轉置的)
     * @param ALPHA 實數係數
     * @param B_gpu 輸入矩陣
     * @param ldb B_gpu的列數(不做轉置)或者行數(做轉置,且傳入的是轉置後B_gpu即B_gpu^T的行數)
     * @param A_gpu 輸入矩陣
     * @param lda A_gpu的列數(不做轉置)或者行數(做轉置,且傳入的是轉置後A_gpu即A_gpu^T的行數)
     * @param BETA 實數係數
     * @param C_gpu 計算結果
     * @param ldc C_gpu的列數
     *
     */
    cudaError_t status = cublasSgemm(handle, (TB ? CUBLAS_OP_T : CUBLAS_OP_N),
            (TA ? CUBLAS_OP_T : CUBLAS_OP_N), N, M, K, &ALPHA, B_gpu, ldb, A_gpu, lda, &BETA, C_gpu, ldc);
            // 檢查cublasSgemm運算是否正常
    check_error(status);
}

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

void time_gpu_random_matrix(int TA, int TB, int m, int k, int n)
{
    float *a;
    if(!TA) a = random_matrix(m,k);
    else a = random_matrix(k,m);
    int lda = (!TA)?k:m;
    float *b;
    if(!TB) b = random_matrix(k,n);
    else b = random_matrix(n,k);
    int ldb = (!TB)?n:k;

    float *c = random_matrix(m,n);
    int i;
    clock_t start = clock(), end;
    for(i = 0; i<32; ++i){
        gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
    }
    end = clock();
    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf s\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
    free(a);
    free(b);
    free(c);
}

void time_gpu(int TA, int TB, int m, int k, int n)
{
    int iter = 10;
    float *a = random_matrix(m,k);
    float *b = random_matrix(k,n);

    int lda = (!TA)?k:m;
    int ldb = (!TB)?n:k;

    float *c = random_matrix(m,n);

    float *a_cl = cuda_make_array(a, m*k);
    float *b_cl = cuda_make_array(b, k*n);
    float *c_cl = cuda_make_array(c, m*n);

    int i;
    clock_t start = clock(), end;
    for(i = 0; i<iter; ++i){
        gemm_gpu(TA,TB,m,n,k,1,a_cl,lda,b_cl,ldb,1,c_cl,n);
        cudaThreadSynchronize();
    }
    double flop = ((double)m)*n*(2.*k + 2.)*iter;
    double gflop = flop/pow(10., 9);
    end = clock();
    double seconds = sec(end-start);
    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf s, %lf GFLOPS\n",m,k,k,n, TA, TB, seconds, gflop/seconds);
    cuda_free(a_cl);
    cuda_free(b_cl);
    cuda_free(c_cl);
    free(a);
    free(b);
    free(c);
}


void test_gpu_accuracy(int TA, int TB, int m, int k, int n)
{
    srand(0);
    float *a;
    if(!TA) a = random_matrix(m,k);
    else a = random_matrix(k,m);
    int lda = (!TA)?k:m;
    float *b;
    if(!TB) b = random_matrix(k,n);
    else b = random_matrix(n,k);
    int ldb = (!TB)?n:k;

    float *c = random_matrix(m,n);
    float *c_gpu = random_matrix(m,n);
    memset(c, 0, m*n*sizeof(float));
    memset(c_gpu, 0, m*n*sizeof(float));
    int i;
    //pm(m,k,b);
    gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c_gpu,n);
    //printf("GPU\n");
    //pm(m, n, c_gpu);

    gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
    //printf("\n\nCPU\n");
    //pm(m, n, c);
    double sse = 0;
    for(i = 0; i < m*n; ++i) {
        //printf("%f %f\n", c[i], c_gpu[i]);
        sse += pow(c[i]-c_gpu[i], 2);
    }
    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %g SSE\n",m,k,k,n, TA, TB, sse/(m*n));
    free(a);
    free(b);
    free(c);
    free(c_gpu);
}

int test_gpu_blas()
{
    /*
       test_gpu_accuracy(0,0,10,576,75); 

       test_gpu_accuracy(0,0,17,10,10); 
       test_gpu_accuracy(1,0,17,10,10); 
       test_gpu_accuracy(0,1,17,10,10); 
       test_gpu_accuracy(1,1,17,10,10); 

       test_gpu_accuracy(0,0,1000,10,100); 
       test_gpu_accuracy(1,0,1000,10,100); 
       test_gpu_accuracy(0,1,1000,10,100); 
       test_gpu_accuracy(1,1,1000,10,100); 

       test_gpu_accuracy(0,0,10,10,10); 

       time_gpu(0,0,64,2916,363); 
       time_gpu(0,0,64,2916,363); 
       time_gpu(0,0,64,2916,363); 
       time_gpu(0,0,192,729,1600); 
       time_gpu(0,0,384,196,1728); 
       time_gpu(0,0,256,196,3456); 
       time_gpu(0,0,256,196,2304); 
       time_gpu(0,0,128,4096,12544); 
       time_gpu(0,0,128,4096,4096); 
     */
    time_gpu(0,0,64,75,12544); 
    time_gpu(0,0,64,75,12544); 
    time_gpu(0,0,64,75,12544); 
    time_gpu(0,0,64,576,12544); 
    time_gpu(0,0,256,2304,784); 
    time_gpu(1,1,2304,256,784); 
    time_gpu(0,0,512,4608,196); 
    time_gpu(1,1,4608,512,196); 

    return 0;
}
#endif

其實,gemm總結起來就完成一個矩陣乘法的運算:C = ALPHA * A * B + BETA * C

上述公式中,A,B,C爲矩陣,A,B爲輸入矩陣,C矩陣保存運算結果。ALPHA,BETA爲係數。這樣看起來是不是很簡單,接下來我們需要考慮矩陣A,B的行數和列數分別是多少,這裏我們假設矩陣A爲[M,K],矩陣B爲[K,N],那麼矩陣C爲[M,N]。我們都直到矩陣A,B,C在邏輯是一個二維的結果,在這裏實際的存儲結構是一個一維數組,按行存儲。

接下來,我們來具體看一個例子,爲了方便運算,我們這裏假設ALPHA=1,BETA=0。實際上我們只對A*B進行運算,我們分爲四種情況進行討論,爲什麼要分爲四種情況呢?其實就是引入矩陣的轉置。

1. A * B

 

 

 綜合計算一下,矩陣C的內容如下:

 

2. A^T * B

 

 其實這跟A* B h很想象。

 

3. A * B^T

 綜合一下,便可以得到結果

 

 

4 A^T * B^T

 跟A * B^T很想象。

 ok,此時你再取看源碼豁然開朗。

gemm_nn 函數就是計算 A * B這種類型;

gemm_tn 函數就是計算A^T * B這種類型;

gemm_nt 函數就是計算A * B^T這種類型;

gemm_tt 函數就是計算A^T * B^T 這種類型;

gemm_cpu 函數就是根據矩陣A和B的情況來實際調用 gemm_nn 、gemm_tn 、gemm_nt、gemm_tt 函數;

gemm 函數其實就是在gemm_cpu函數上再封裝一層,參數原封不動傳遞給gemm_cpu函數;

完,

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章