CUDA使用筆記(一)矩陣乘法

簡介

本文介紹cublasSgemm()函數的使用。在c/c++中,通常我們將2維矩陣按行存儲爲一維數組。但是在顯存中,矩


陣是按列存儲的。因此,我們在實際使用時,對cublasSgemm()中的各個參數的賦值可能會搞不清楚。


本文,以一個具體的矩陣乘法案例爲例子,介紹cublasSgemm()函數的使用。


正文:

我們以下圖所示的矩陣運算爲例進行講解。


因爲gpu顯存中,矩陣是按列進行存儲的。如果我們在gpu中直接計算F=D×E的話,那麼將結果矩陣從顯存中取出


放回內存中時,內存中實際存儲的則是FT


因此,如果我們能夠計算FT=ET× DT,那麼我們將結果取回內存時,內存中保存的剛好就是矩陣F


基於上述的解釋,我們在gpu中實際計算的是FT=ET× DT,取回內存後,剛好是我們需要的結果F


好的,下面我們說一下,怎麼使用cublasSgemm()函數計算FT=ET× DT

cublasSgemm()函數的輸入參數如下所示


cublasStatus_t cublasSgemm (
cublasHandle_t handle, 
cublasOperation_t transa,
cublasOperation_t transb, 
int m,
int n,
int k,
const float *alpha, /* host or device pointer */  
const float *A, 
int lda,
const float *B,
int ldb, 
const float *beta, /* host or device pointer */  
float *C,
int ldc);

上述函數直接計算的是C=alpha*A×B+beta*C


transa的值一般取CUBLAS_OP_NCUBLAS_OP_T分別表示顯存中A是否轉置;


transb的值同理。


mA矩陣的列,nB矩陣的行,kA矩陣的行(B矩陣的列)。


ldaldbldc分別表示ABC矩陣的列。





綜上所述,結合實際例子,我們可以知道各個參數的取值。

因爲我們計算的是FT=ET× DT,因此A矩陣取EB矩陣取D


transaCUBLAS_OP_NtransbCUBLAS_OP_Nalpha=1beta=0

E矩陣是2×3D矩陣是3×2,因此m=3E的列),n=3D的行),k=2E的行或者D的列。


lda =3E的列),ldb=2D的列),ldc=3F的列)。


實際調用情況如下

cublasSgemm(handle,CUBLAS_OP_N, CUBLAS_OP_N, 3, 3, 2, &alpha, E, 3, D, 2, &beta,F, 3)

完整測試程序如下,注意,部分參數名字可能與上面所述有出入。


#include <cublas_v2.h>  
#include <cuda.h> 
#include <cuda_runtime.h> 
#include <iostream>
#include <stdio.h>  
int main(void)  
{  
    float alpha=1.0;  
    float beta=0.0;  
    float h_D[6]={1,1,2,2,3,3};  
    float h_E[6]={1,2,3,4,5,6};  
    float h_F[9];  
    float *d_D,*d_E,*d_F;  
    cudaMalloc((void**)&d_D,6*sizeof(float));  
    cudaMalloc((void**)&d_E,6*sizeof(float));  
    cudaMalloc((void**)&d_F,9*sizeof(float));  
    cudaMemcpy(d_D,&h_D,6*sizeof(float),cudaMemcpyHostToDevice);  
    cudaMemcpy(d_E,&h_E,6*sizeof(float),cudaMemcpyHostToDevice);  
    cudaMemset(d_F,0,9*sizeof(float));  
    cublasHandle_t handle;  
    cublasCreate(&handle);  
    cublasSgemm(handle,CUBLAS_OP_N,CUBLAS_OP_N,3,3,2,&alpha,d_E,3,d_D,2,&beta,d_F,3);  
    cudaMemcpy(h_F,d_F,9*sizeof(float),cudaMemcpyDeviceToHost);  
    for(int i=0;i<9;i++)  
    {  
        printf("%f\n",h_F[i]);  
    }  
    printf("\n");  
    return 0;  
}







發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章