LSTM的cuda加速

目的是加速雲端算法中LSTM部分的計算:將.cu文件編譯為.so動態庫,用此.so中的LSTM類或函數替代tf.LSTMCell進行運算。
整個項目見Github,流程見博客,博主也剛入門cuda,歡迎留言探討~

TensorFlow LSTM benchmark

測試代碼見Github文件夾tensorflow_LSTM_benchmark。

TensorFlow提供5種LSTM變體:(1)BasicLSTMCell,(2)LSTMCell,(3)LSTMBlockCell,(4)LSTMBlockFusedCell和(5)cuDNNLSTM

測試環境GTX1080Ti:seqLength 100, numLayers 1, hiddenSize 512, miniBatch 64 執行1000次取均值

| Cell | CPU | GPU | 分析(CPU 與 GPU) |
|---|---|---|---|
| BasicLSTMCell | 195.74ms | 46.89ms | 時間步上用tf.dynamic_rnn循環,由於實現簡單,速度相對較快 |
| LSTMCell | 202.13ms | 50.83ms | 標準LSTM,有更多的參數選擇,速度與BasicLSTMCell相差不多 |
| LSTMBlockCell | 254.76ms | 48.72ms | 用於單運行RNN交互場景,時間步上和tf.while_loop結合使用;CPU上速度明顯較慢,理論上在GPU上應該比BasicLSTMCell快,但實際速度相差不多 |
| LSTMBlockFusedCell | 190.48ms | 23.06ms | CPU上速度輕微加快,GPU上如預期快很多 |
| cuDNNLSTM | / | 18.63ms | 在CPU上默認執行LSTMBlockFused,在GPU上是5種變體中速度最快的 |
| cudaLSTM | / | 11.36ms | 自己寫的實現(源碼如下),比TensorFlow自帶版本更快 |

計算過程

(原文此處為LSTM計算過程示意圖)

需要寫一個main函數調用下方代碼,並與tf.LSTMCell的輸出對比以驗證一致性:


// Fused LSTM forward element-wise kernel for a single time step.
//
// Launch layout: 1-D grid with one thread per (batch row, hidden unit)
// element. Threads with index >= miniBatch * hiddenSize return immediately,
// so any grid providing at least miniBatch * hiddenSize threads is valid.
// No shared memory is used.
//
// Gate pre-activations are consumed in the order i, j, f, o (input gate,
// candidate, forget gate, output gate). For batch row b and hidden unit h,
// gate k is read from [b * 4 * hiddenSize + k * hiddenSize + h] in both
// tmp_i and tmp_h, and its bias from bias[k * hiddenSize + h].
//
// Per element it computes:
//   i     = sigmoid(x_i + h_i + b_i [+ w_i_diag * c_in])
//   j     = tanh   (x_j + h_j + b_j)
//   f     = sigmoid(x_f + h_f + b_f [+ w_f_diag * c_in])
//   c_out = clip(f * c_in + i * j)      clipped to [-cell_clip, cell_clip]
//                                       only when cell_clip > 0
//   o     = sigmoid(x_o + h_o + b_o [+ w_o_diag * c_out])
//   i_out = o * tanh(c_out)
// Bracketed peephole terms are applied only when use_peepholes is true.
//
// NOTE(review): no forget bias is added to f here, whereas tf.LSTMCell adds
// forget_bias (default 1.0). Presumably the caller folds it into `bias`;
// confirm.
//
// Parameters:
//   tmp_h, tmp_i  recurrent / input gate contributions, miniBatch x 4*hiddenSize.
//   bias          4*hiddenSize gate biases.
//   linearGates   receives the pre-activation gates; only written when training.
//   h_out         unused (kept for interface compatibility).
//   i_out         receives the cell output o * tanh(c), miniBatch * hiddenSize.
//   c_in, c_out   previous / new cell state, miniBatch * hiddenSize each.
//   w_*_diag      peephole weights (hiddenSize each); read only if use_peepholes.
//   num_proj      unused (kept for interface compatibility).
//   cell_clip     cell-state clip value; a value <= 0 disables clipping.
__global__ void elementWise_fp(int hiddenSize, int miniBatch,
                               float *tmp_h,
                               float *tmp_i,
                               float *bias,
                               float *linearGates,
                               float *h_out,
                               float *i_out,
                               float *c_in,
                               float *c_out,
                               bool training,
                               float *w_i_diag,
                               float *w_f_diag,
                               float *w_o_diag,
                               bool use_peepholes,
                               int num_proj,
                               float cell_clip) {
   (void)h_out;     // not written by this kernel
   (void)num_proj;  // projection is performed by the caller

   const int index = blockIdx.x * blockDim.x + threadIdx.x;
   const int numElements = miniBatch * hiddenSize;
   if (index >= numElements) return;

   const int batch = index / hiddenSize;
   const int hid_index = index % hiddenSize;
   // Offset of this batch row's first gate value in the
   // miniBatch x 4*hiddenSize gate matrices.
   const int gateIndex = batch * 4 * hiddenSize;

   // Pre-activations in order i, j, f, o: input + recurrent + bias.
   float g[4];
#pragma unroll
   for (int k = 0; k < 4; k++) {
      const int offset = gateIndex + k * hiddenSize + hid_index;
      g[k] = tmp_i[offset] + tmp_h[offset];
      g[k] += bias[k * hiddenSize + hid_index];
      if (training) linearGates[offset] = g[k];
   }

   const float c_prev = c_in[index];

   // Input and forget gates see the previous cell state through the peepholes.
   float in_gate_pre = g[0];
   float forget_gate_pre = g[2];
   if (use_peepholes) {
      in_gate_pre += w_i_diag[hid_index] * c_prev;
      forget_gate_pre += w_f_diag[hid_index] * c_prev;
   }
   const float in_gate = sigmoidf(in_gate_pre);
   const float forget_gate = sigmoidf(forget_gate_pre);
   const float candidate = tanhf(g[1]);

   float c = (forget_gate * c_prev) + (in_gate * candidate);

   // Explicit comparisons (rather than fminf/fmaxf) so a NaN cell state
   // propagates unchanged, as before.
   if (cell_clip > 0.0f) {
      if (c > cell_clip)
         c = cell_clip;
      if (c < -cell_clip)
         c = -cell_clip;
   }
   c_out[index] = c;

   // The output gate's peephole uses the *new* (clipped) cell state.
   float out_gate_pre = g[3];
   if (use_peepholes) out_gate_pre += w_o_diag[hid_index] * c;
   const float out_gate = sigmoidf(out_gate_pre);

   i_out[index] = out_gate * tanhf(c);
}


// Runs a single-layer LSTM forward pass over a whole sequence on the GPU,
// modelled on tf.LSTMCell with optional peepholes, cell clipping and output
// projection. It then copies the final state and per-step outputs back to host.
//
// Host buffer layouts (row-major, as implied by the column-major cuBLAS calls):
//   input              [seqLength][miniBatch][inputSize]
//   weight_i           [inputSize][4*hiddenSize]   gate order i, j, f, o
//   weight_h           [h_depth][4*hiddenSize]
//   bias_data_in       [4*hiddenSize]
//   w_{i,f,o}_diag_in  [hiddenSize] peephole weights
//   proj_kernel_in     [hiddenSize][h_depth]
//   h_data_in / out    [miniBatch][h_depth]
//   c_data_in / out    [miniBatch][hiddenSize]
//   output             [seqLength][miniBatch][h_depth]
// where h_depth = outSize when use_peepholes is true, else hiddenSize.
// Every input buffer above is uploaded unconditionally, so each must be a
// valid host pointer even when the feature that reads it is disabled.
//
// NOTE(review): the output projection (proj_kernel_in / proj_clip) is applied
// if and only if use_peepholes is true, while tf.LSTMCell controls num_proj
// and use_peepholes independently. Confirm this coupling is intended.
//
// cell_clip > 0 clips the cell state to [-cell_clip, cell_clip];
// proj_clip != 0 clips the projected output (projection path only).
// When seqLength == 0 the initial state is returned unchanged.
// On return, all device work has completed and every resource created here
// (device buffer, streams, event, cuBLAS handle) has been released.
void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,
              float* input, float* c_data_in, float *h_data_in, float* weight_i, float* weight_h, float*bias_data_in, float* w_i_diag_in,
              float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in, float* c_data_out, float* h_data_out, float* output,
              bool use_peepholes = true, float cell_clip = 0.0, float proj_clip = 0.0){
    const float alpha = 1.f;
    const float beta  = 0.f;
    const int input_depth = inputSize;
    const int gateSize = hiddenSize * 4;
    const int h_depth = use_peepholes ? outSize : hiddenSize;

    // Element counts are size_t so large shapes cannot overflow int.
    const size_t numElements = (size_t)miniBatch * hiddenSize;   // one cell-state tile
    const size_t hTile       = (size_t)miniBatch * h_depth;      // one hidden-state tile
    const size_t gateTile    = (size_t)miniBatch * gateSize;     // gates for one step
    const size_t inputCount  = (size_t)seqLength * miniBatch * input_depth;

    // Sizes of the sub-buffers carved out of one device allocation.
    const size_t w_diag_bias_proj_size = (size_t)(7 + h_depth) * hiddenSize;
    const size_t input_T_size   = inputCount + (size_t)(input_depth + h_depth) * gateSize;
    const size_t h_op_data_size = (2 + (size_t)seqLength) * hTile;
    const size_t c_o_data_size  = (2 + (size_t)seqLength) * numElements;
    const size_t tmp_i_size     = (size_t)seqLength * gateTile;
    const size_t tmp_h_size     = gateTile;

    printf("the seqLength is: %d, inputSize: %d, input_depth: %d, hiddenSize: %d, outSize: %d\n", seqLength, inputSize, input_depth, hiddenSize, outSize);

    cudaErrCheck(cudaGetLastError());

    float *w_diag_bias_proj;
    cudaErrCheck(cudaMalloc((void**)&w_diag_bias_proj,
                            (w_diag_bias_proj_size + input_T_size + h_op_data_size +
                             c_o_data_size + tmp_i_size + tmp_h_size) * sizeof(float)));

    // [w_i_diag | w_f_diag | w_o_diag | bias (4H) | projection (H x h_depth)]
    float *w_i_diag    = w_diag_bias_proj;
    float *w_f_diag    = w_diag_bias_proj + hiddenSize;
    float *w_o_diag    = w_diag_bias_proj + 2 * hiddenSize;
    float *bias        = w_diag_bias_proj + 3 * hiddenSize;
    float *proj_kernel = w_diag_bias_proj + 7 * hiddenSize;
    // [input sequence | W_i (in x 4H) | W_h (h_depth x 4H)]
    float *input_T      = w_diag_bias_proj + w_diag_bias_proj_size;
    float *weight_i_dev = input_T + inputCount;
    float *weight_h_dev = weight_i_dev + (size_t)input_depth * gateSize;
    // [h_0 | slot passed as h_out (not written by the kernel) | outputs for each step]
    float *h_op_data = input_T + input_T_size;
    float *h_outputs = h_op_data + 2 * hTile;
    // [cell state A | cell state B | pre-projection cell outputs for each step]
    float *c_o_data     = h_op_data + h_op_data_size;
    float *cell_outputs = c_o_data + 2 * numElements;
    // Gate contributions: all steps from the input GEMM, one step from the recurrent GEMM.
    float *tmp_i = c_o_data + c_o_data_size;
    float *tmp_h = tmp_i + tmp_i_size;

    // stream_i: input upload + the single big input GEMM, overlapping with
    // stream_h, which carries every other upload and the whole recurrence.
    cudaStream_t stream_i, stream_h;
    cudaErrCheck(cudaStreamCreate(&stream_i));
    cudaErrCheck(cudaStreamCreate(&stream_h));
    cudaEvent_t input_gemm_done;
    cudaErrCheck(cudaEventCreateWithFlags(&input_gemm_done, cudaEventDisableTiming));

    cudaErrCheck(cudaMemcpyAsync(input_T, input, inputCount * sizeof(float), cudaMemcpyHostToDevice, stream_i));
    cudaErrCheck(cudaMemcpyAsync(weight_i_dev, weight_i, (size_t)input_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_i));
    cudaErrCheck(cudaMemcpyAsync(weight_h_dev, weight_h, (size_t)h_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(h_op_data, h_data_in, hTile * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(c_o_data, c_data_in, numElements * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(w_i_diag, w_i_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(w_f_diag, w_f_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(w_o_diag, w_o_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(bias, bias_data_in, gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(proj_kernel, proj_kernel_in, (size_t)h_depth * hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaGetLastError());

    cublasHandle_t handle;
    cublasErrCheck(cublasCreate(&handle));

    // Input contribution for every step at once (column-major view):
    //   tmp_i (4H x seq*mb) = W_i (4H x in) * X (in x seq*mb)
    cublasErrCheck(cublasSetStream(handle, stream_i));
    cublasErrCheck(cublasSgemm(handle,
                               CUBLAS_OP_N, CUBLAS_OP_N,
                               gateSize, miniBatch * seqLength, input_depth,
                               &alpha,
                               weight_i_dev, gateSize,
                               input_T, input_depth,
                               &beta,
                               tmp_i, gateSize));
    cudaErrCheck(cudaEventRecord(input_gemm_done, stream_i));
    cudaErrCheck(cudaGetLastError());

    // Everything from here on is ordered on stream_h.
    cublasErrCheck(cublasSetStream(handle, stream_h));

    const dim3 block(256);
    const dim3 grid((unsigned)((numElements + block.x - 1) / block.x));
    const dim3 clip_grid((unsigned)((hTile + block.x - 1) / block.x));

    // Ping-pong pointers replace the per-step device-to-device state copies.
    float *h_prev = h_op_data;            // h_{t-1}, starts at the uploaded h_0
    float *c_prev = c_o_data;             // c_{t-1}, starts at the uploaded c_0
    float *c_next = c_o_data + numElements;

    for (int i = 0; i < seqLength; ++i) {
        float *tmp_i_step = tmp_i + (size_t)i * gateTile;
        float *h_step = h_outputs + (size_t)i * hTile;
        // Without projection the cell output *is* h_t, so the kernel writes it
        // straight into this step's output slot. Previously nothing filled
        // that slot on this path, and garbage was fed back as h.
        float *cell_out_step = use_peepholes ? cell_outputs + (size_t)i * numElements : h_step;

        // Recurrent contribution: tmp_h (4H x mb) = W_h (4H x h_depth) * h_{t-1}
        cublasErrCheck(cublasSgemm(handle,
                                   CUBLAS_OP_N, CUBLAS_OP_N,
                                   gateSize, miniBatch, h_depth,
                                   &alpha,
                                   weight_h_dev, gateSize,
                                   h_prev, h_depth,
                                   &beta,
                                   tmp_h, gateSize));

        // The first kernel is the first consumer of tmp_i. Waiting here rather
        // than before the loop lets step 0's recurrent GEMM overlap the input GEMM.
        if (i == 0)
            cudaErrCheck(cudaStreamWaitEvent(stream_h, input_gemm_done, 0));

        elementWise_fp <<< grid, block, 0, stream_h >>>
        (hiddenSize, miniBatch,
         tmp_h,
         tmp_i_step,
         bias,
         NULL,
         h_op_data + hTile,
         cell_out_step,
         c_prev,
         c_next,
         false,
         w_i_diag,
         w_f_diag,
         w_o_diag,
         use_peepholes,
         h_depth,
         cell_clip);
        cudaErrCheck(cudaGetLastError());

        if (use_peepholes) {
            // Output projection: h_t (h_depth x mb) = P (h_depth x H) * cell_out (H x mb)
            cublasErrCheck(cublasSgemm(handle,
                                       CUBLAS_OP_N, CUBLAS_OP_N,
                                       h_depth, miniBatch, hiddenSize,
                                       &alpha,
                                       proj_kernel, h_depth,
                                       cell_out_step, hiddenSize,
                                       &beta,
                                       h_step, h_depth));
            if (proj_clip != 0) {
                clip_by_value <<< clip_grid, block, 0, stream_h >>>
                (h_step, proj_clip, miniBatch * h_depth);
                cudaErrCheck(cudaGetLastError());
            }
        }

        h_prev = h_step;
        float *c_swap = c_prev;
        c_prev = c_next;
        c_next = c_swap;
    }

    // Drain both streams before reading results or releasing resources.
    // This no longer relies on legacy default-stream implicit synchronization.
    cudaErrCheck(cudaStreamSynchronize(stream_i));
    cudaErrCheck(cudaStreamSynchronize(stream_h));

    // h_prev / c_prev now hold the latest state (the initial state if seqLength == 0).
    cudaErrCheck(cudaMemcpy(h_data_out, h_prev, hTile * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(c_data_out, c_prev, numElements * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(output, h_outputs, (size_t)seqLength * hTile * sizeof(float), cudaMemcpyDeviceToHost));

    cudaErrCheck(cudaEventDestroy(input_gemm_done));
    cublasErrCheck(cublasDestroy(handle));   // previously leaked
    cudaErrCheck(cudaStreamDestroy(stream_i));
    cudaErrCheck(cudaStreamDestroy(stream_h));
    cudaErrCheck(cudaFree(w_diag_bias_proj));
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章