目的是加速雲端算法中 LSTM 部分的計算:將 .cu 文件編譯成 .so 動態庫,並用該 .so 中的 LSTM 類或函數替代 tf.LSTMCell 進行運算。
整個項目見Github,流程見博客,博主也剛入門cuda,歡迎留言探討~
TensorFlow LSTM benchmark
測試代碼見Github文件夾tensorflow_LSTM_benchmark。
TensorFlow提供5種LSTM變體:(1)BasicLSTMCell,(2)LSTMCell,(3)LSTMBlockCell,(4)LSTMBlockFusedCell和(5)cuDNNLSTM
測試環境:GTX 1080 Ti;參數 seqLength = 100、numLayers = 1、hiddenSize = 512、miniBatch = 64;每種實現執行 1000 次取平均耗時。
Cell | CPU | GPU | 分析(CPU and GPU) |
---|---|---|---|
BasicLSTMCell | 195.74ms | 46.89ms | 時間步上用tf.dynamic_rnn循環,由於實現簡單,速度相對較快 |
LSTMCell | 202.13ms | 50.83ms | 標準LSTM,有更多的參數選擇,速度相差不多 |
LSTMBlockCell | 254.76ms | 48.72ms | 用於單運行RNN交互場景,時間步上和tf.while_loop結合使用,CPU上速度明顯慢,理論上在GPU上應該比BasicLSTMCell快,實際速度相差不多 |
LSTMBlockFusedCell | 190.48ms | 23.06ms | CPU上速度輕微加快,GPU如預期速度快很多 |
cuDNNLSTM | / | 18.63ms | 在CPU上默認執行LSTMBlockFused,在GPU上是5種變體速度最快的 |
cudaLSTM | / | 11.36ms | 自己寫的源碼如下,比tf自帶版本要快 |
計算過程
需要另寫一個 main 函數調用以下代碼,並與 tf.LSTMCell 的輸出對比以驗證一致性:
// Fused forward kernel
//
// Element-wise part of one LSTM time step. Gate order inside every
// 4 * hiddenSize row is i, j, f, o:
//   i = input gate, j = candidate (tanh), f = forget gate, o = output gate.
//
// Launch: 1-D grid with gridDim.x * blockDim.x >= miniBatch * hiddenSize,
// no shared memory. One thread per (batch, hidden) element; surplus threads
// return immediately.
//
// Parameters:
//   tmp_h, tmp_i  pre-activations from the recurrent / input GEMMs, laid out
//                 as miniBatch rows of 4 * hiddenSize (gate-major inside a row)
//   bias          4 * hiddenSize values, one row of hiddenSize per gate.
//                 NOTE(review): no separate forget bias is added here;
//                 presumably the caller folds it into `bias` -- confirm.
//   linearGates   if `training`, receives the biased pre-activations at the
//                 same offsets as tmp_h / tmp_i; may be NULL otherwise
//   h_out         unused (kept for interface compatibility)
//   i_out         receives o * tanh(c), miniBatch * hiddenSize values
//   c_in, c_out   previous / new cell state, miniBatch * hiddenSize each;
//                 must not alias
//   w_*_diag      peephole weights (hiddenSize each); read only when
//                 use_peepholes is true
//   num_proj      unused (projection is done by the caller)
//   cell_clip     if > 0, the new cell state is clipped to [-cell_clip, cell_clip]
//
// sigmoidf is a device helper defined elsewhere in this file.
__global__ void elementWise_fp(int hiddenSize, int miniBatch,
                               float *tmp_h,
                               float *tmp_i,
                               float *bias,
                               float *linearGates,
                               float *h_out,
                               float *i_out,
                               float *c_in,
                               float *c_out,
                               bool training,
                               float *w_i_diag,
                               float *w_f_diag,
                               float *w_o_diag,
                               bool use_peepholes,
                               int num_proj,
                               float cell_clip) {
    (void)h_out;     // unused: the caller writes the hidden state itself
    (void)num_proj;  // unused: projection is a separate GEMM on the host side

    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int numElements = miniBatch * hiddenSize;
    if (index >= numElements) return;

    int batch = index / hiddenSize;
    int hid_index = index % hiddenSize;
    // First element of this batch's row in the [miniBatch][4 * hiddenSize] buffers.
    int gateIndex = batch * 4 * hiddenSize;

    // Pre-activation of each gate: input GEMM + recurrent GEMM + bias.
    float g[4];
#pragma unroll
    for (int gate = 0; gate < 4; gate++) {
        int offset = gateIndex + gate * hiddenSize + hid_index;
        g[gate] = tmp_i[offset] + tmp_h[offset] + bias[gate * hiddenSize + hid_index];
        if (training) linearGates[offset] = g[gate];
    }

    // Previous cell state, loaded once and kept in a register.
    float c_prev = c_in[index];
    float candidate = tanhf(g[1]);

    float in_gate, forget_gate;
    if (use_peepholes) {
        // Peepholes: the input and forget gates also see the previous cell state.
        in_gate = sigmoidf(g[0] + w_i_diag[hid_index] * c_prev);
        forget_gate = sigmoidf(g[2] + w_f_diag[hid_index] * c_prev);
    } else {
        in_gate = sigmoidf(g[0]);
        forget_gate = sigmoidf(g[2]);
    }

    // c = f * c_prev + i * tanh(j), optionally clipped.
    float c = (forget_gate * c_prev) + (in_gate * candidate);
    if (cell_clip > 0.0f) {  // float literal: avoid an FP64 compare on device
        // Two comparisons (not fminf/fmaxf) so a NaN cell state passes through unchanged.
        if (c > cell_clip) c = cell_clip;
        if (c < -cell_clip) c = -cell_clip;
    }
    c_out[index] = c;

    // h = o * tanh(c); with peepholes the output gate sees the new (clipped) cell state.
    float out_gate = use_peepholes ? sigmoidf(g[3] + w_o_diag[hid_index] * c)
                                   : sigmoidf(g[3]);
    i_out[index] = out_gate * tanhf(c);
}
// Single-layer LSTM forward pass over `seqLength` time steps on the GPU.
// It is meant to be driven from a separate main() and compared against tf LSTMCell.
//
// All array arguments are host pointers. They are uploaded on every call, and
// every device buffer, both streams and the cuBLAS handle are created and
// released inside this call.
//
// Layouts are row-major, as implied by the column-major cuBLAS calls below:
//   input          [seqLength][miniBatch][inputSize]    (time-major)
//   weight_i       [inputSize][4 * hiddenSize]
//   weight_h       [h_depth][4 * hiddenSize]
//   bias_data_in   [4 * hiddenSize]  (gate order as consumed by elementWise_fp)
//   h_data_in      [miniBatch][h_depth]      initial hidden state
//   c_data_in      [miniBatch][hiddenSize]   initial cell state
//   w_*_diag_in    [hiddenSize] each         read only when use_peepholes
//   proj_kernel_in [hiddenSize][h_depth]     read only when use_peepholes
//   h_data_out     [miniBatch][h_depth]      final hidden state
//   c_data_out     [miniBatch][hiddenSize]   final cell state
//   output         [seqLength][miniBatch][h_depth]  hidden state of every step
//
// h_depth is outSize when use_peepholes is true and hiddenSize otherwise.
// In this implementation peepholes and the output projection (with optional
// proj_clip) are switched on together.
// NOTE(review): tf LSTMCell treats use_peepholes and num_proj independently --
// confirm this coupling is intended.
// A cell_clip or proj_clip of 0 disables that clip.
// With seqLength == 0 the initial state is returned unchanged.
void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,
    float* input, float* c_data_in, float *h_data_in, float* weight_i, float* weight_h, float*bias_data_in, float* w_i_diag_in,
    float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in, float* c_data_out, float* h_data_out, float* output,
    bool use_peepholes = true, float cell_clip = 0.0, float proj_clip = 0.0){
    const float alpha = 1.f;
    const float beta = 0.f;
    const int input_depth = inputSize;
    const int gateSize = hiddenSize * 4;
    const int h_depth = use_peepholes ? outSize : hiddenSize;
    const int numElements = miniBatch * hiddenSize;  // one cell-state slice
    const int hElements = miniBatch * h_depth;       // one hidden-state slice

    // Element counts of the sub-buffers carved out of a single device allocation
    // (size_t so that large configurations cannot overflow int).
    const size_t inputCount            = (size_t)miniBatch * seqLength * input_depth;
    const size_t w_diag_bias_proj_size = (size_t)(7 + h_depth) * hiddenSize;                  // 3 peephole vectors, 4 bias rows, projection
    const size_t input_T_size          = inputCount + (size_t)(input_depth + h_depth) * gateSize; // input, weight_i, weight_h
    const size_t h_op_data_size        = (size_t)(2 + seqLength) * hElements;                 // h_in, h_out, per-step outputs
    const size_t c_o_data_size         = (size_t)(2 + seqLength) * numElements;               // c_in, c_out, per-step o*tanh(c)
    const size_t tmp_i_size            = (size_t)seqLength * miniBatch * gateSize;
    const size_t tmp_h_size            = (size_t)miniBatch * gateSize;

    printf("the seqLength is: %d, inputSize: %d, input_depth: %d, hiddenSize: %d, outSize: %d\n", seqLength, inputSize, input_depth, hiddenSize, outSize);
    cudaErrCheck(cudaGetLastError());

    // One allocation for everything, then named views into it.
    float *w_diag_bias_proj;
    cudaErrCheck(cudaMalloc((void**)&w_diag_bias_proj,
        (w_diag_bias_proj_size + input_T_size + h_op_data_size + c_o_data_size + tmp_i_size + tmp_h_size) * sizeof(float)));
    float *input_T   = w_diag_bias_proj + w_diag_bias_proj_size;
    float *h_op_data = input_T + input_T_size;
    float *c_o_data  = h_op_data + h_op_data_size;
    float *tmp_i     = c_o_data + c_o_data_size;
    float *tmp_h     = tmp_i + tmp_i_size;

    float *d_weight_i = input_T + inputCount;
    float *d_weight_h = d_weight_i + (size_t)input_depth * gateSize;
    float *d_w_i_diag = w_diag_bias_proj;
    float *d_w_f_diag = w_diag_bias_proj + hiddenSize;
    float *d_w_o_diag = w_diag_bias_proj + 2 * hiddenSize;
    float *d_bias     = w_diag_bias_proj + 3 * hiddenSize;
    float *d_proj     = w_diag_bias_proj + 7 * hiddenSize;
    float *h_in       = h_op_data;
    float *h_out      = h_op_data + hElements;
    float *h_seq      = h_op_data + 2 * (size_t)hElements;
    float *c_in       = c_o_data;
    float *c_out      = c_o_data + numElements;
    float *cell_seq   = c_o_data + 2 * (size_t)numElements;

    // stream_i runs the one-off input GEMM.
    // stream_h runs the recurrent per-step chain.
    // An event orders the first element-wise kernel after the input GEMM
    // without blocking the host.
    cudaStream_t stream_i, stream_h;
    cudaErrCheck(cudaStreamCreate(&stream_i));
    cudaErrCheck(cudaStreamCreate(&stream_h));
    cudaEvent_t inputGemmDone;
    cudaErrCheck(cudaEventCreateWithFlags(&inputGemmDone, cudaEventDisableTiming));

    cudaErrCheck(cudaMemcpyAsync(input_T, input, inputCount * sizeof(float), cudaMemcpyHostToDevice, stream_i));
    cudaErrCheck(cudaMemcpyAsync(d_weight_i, weight_i, (size_t)input_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_i));
    cudaErrCheck(cudaMemcpyAsync(d_weight_h, weight_h, (size_t)h_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(h_in, h_data_in, hElements * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(c_in, c_data_in, numElements * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(d_bias, bias_data_in, gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    if (use_peepholes) {
        // Only read in this mode. Skipping the copies otherwise avoids reading
        // h_depth * hiddenSize floats from a projection buffer the caller may not have.
        cudaErrCheck(cudaMemcpyAsync(d_w_i_diag, w_i_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
        cudaErrCheck(cudaMemcpyAsync(d_w_f_diag, w_f_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
        cudaErrCheck(cudaMemcpyAsync(d_w_o_diag, w_o_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
        cudaErrCheck(cudaMemcpyAsync(d_proj, proj_kernel_in, (size_t)h_depth * hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));
    }
    // Seed the final-state slots so that seqLength == 0 returns the initial state.
    cudaErrCheck(cudaMemcpyAsync(h_out, h_in, hElements * sizeof(float), cudaMemcpyDeviceToDevice, stream_h));
    cudaErrCheck(cudaMemcpyAsync(c_out, c_in, numElements * sizeof(float), cudaMemcpyDeviceToDevice, stream_h));
    cudaErrCheck(cudaGetLastError());

    cublasHandle_t handle;
    cublasErrCheck(cublasCreate(&handle));

    // Input contribution for all time steps at once: tmp_i = input * weight_i.
    cublasErrCheck(cublasSetStream(handle, stream_i));
    cublasErrCheck(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                               gateSize, miniBatch * seqLength, input_depth,
                               &alpha, d_weight_i, gateSize,
                               input_T, input_depth,
                               &beta, tmp_i, gateSize));
    cudaErrCheck(cudaEventRecord(inputGemmDone, stream_i));
    cudaErrCheck(cudaGetLastError());

    // From here on, every GEMM runs on stream_h.
    cublasErrCheck(cublasSetStream(handle, stream_h));

    const dim3 block(256);
    const dim3 cellGrid((numElements + block.x - 1) / block.x);
    const dim3 projGrid((hElements + block.x - 1) / block.x);

    for (int i = 0; i < seqLength; ++i) {
        float *tmp_i_step = tmp_i + (size_t)i * miniBatch * gateSize;
        float *cell_step  = cell_seq + (size_t)i * numElements;  // o * tanh(c) for this step
        float *h_step     = h_seq + (size_t)i * hElements;       // hidden output for this step

        // Recurrent contribution: tmp_h = h_in * weight_h.
        cublasErrCheck(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                   gateSize, miniBatch, h_depth,
                                   &alpha, d_weight_h, gateSize,
                                   h_in, h_depth,
                                   &beta, tmp_h, gateSize));
        if (i == 0)
            cudaErrCheck(cudaStreamWaitEvent(stream_h, inputGemmDone, 0));

        elementWise_fp<<<cellGrid, block, 0, stream_h>>>(
            hiddenSize, miniBatch,
            tmp_h,
            tmp_i_step,
            d_bias,
            NULL,
            h_out,
            cell_step,
            c_in,
            c_out,
            false,
            d_w_i_diag,
            d_w_f_diag,
            d_w_o_diag,
            use_peepholes,
            h_depth,
            cell_clip);
        cudaErrCheck(cudaGetLastError());

        if (use_peepholes) {
            // Projection: h_step = cell_step * proj_kernel, optionally clipped.
            cublasErrCheck(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                       h_depth, miniBatch, hiddenSize,
                                       &alpha, d_proj, h_depth,
                                       cell_step, hiddenSize,
                                       &beta, h_step, h_depth));
            if (proj_clip != 0) {
                clip_by_value<<<projGrid, block, 0, stream_h>>>(h_step, proj_clip, miniBatch * h_depth);
                cudaErrCheck(cudaGetLastError());
            }
        } else {
            // No projection: h_depth == hiddenSize and the step output is the
            // kernel result itself, so copy it into the hidden-output slot.
            cudaErrCheck(cudaMemcpyAsync(h_step, cell_step, numElements * sizeof(float), cudaMemcpyDeviceToDevice, stream_h));
        }

        // Carry the state to the next step.
        // These are stream-ordered device copies; there is no host sync per step.
        cudaErrCheck(cudaMemcpyAsync(h_out, h_step, hElements * sizeof(float), cudaMemcpyDeviceToDevice, stream_h));
        cudaErrCheck(cudaMemcpyAsync(h_in, h_step, hElements * sizeof(float), cudaMemcpyDeviceToDevice, stream_h));
        cudaErrCheck(cudaMemcpyAsync(c_in, c_out, numElements * sizeof(float), cudaMemcpyDeviceToDevice, stream_h));
    }

    // Make sure all device work is done before reading results on the host,
    // independent of the default-stream mode the file is compiled with.
    cudaErrCheck(cudaStreamSynchronize(stream_i));
    cudaErrCheck(cudaStreamSynchronize(stream_h));
    cudaErrCheck(cudaMemcpy(h_data_out, h_out, hElements * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(c_data_out, c_out, numElements * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(output, h_seq, (size_t)seqLength * hElements * sizeof(float), cudaMemcpyDeviceToHost));

    cublasErrCheck(cublasDestroy(handle));
    cudaErrCheck(cudaEventDestroy(inputGemmDone));
    cudaErrCheck(cudaStreamDestroy(stream_i));
    cudaErrCheck(cudaStreamDestroy(stream_h));
    cudaErrCheck(cudaFree(w_diag_bias_proj));
}