背景
之前也寫過sgemm,然後就想看看MNN是如何實現的,有沒有什麼可以借鑑的地方,看完之後發現MNN的實現也是簡單的按行數據並行處理,記錄一下。
矩陣乘法
矩陣乘法的目的是完成一個計算:C = A * B
,其中A是h * k, B是k * w,所以C是h * w。
常用的方式是分行處理,對於C的第一行,可以按如下方式處理:
C(0,j) += A(0,i)*B(i,j)
對於行主序矩陣,每一行的數據是連續存儲的,我們自然可以考慮使用SIMD指令,一次處理4個(假設是Float32)數據的相乘:
float32x4_t a0 = vdupq_n_f32(aLine[i]);
float32x4_t b0 = vld1q_f32(bLine);
float32x4_t sum0 = vdupq_n_f32(0.0);
sum0 = vmlaq_f32(sum0, a0, b0);
vst1q_f32(cLine, sum0);
需要注意的一點是,如果w不能被4整除,那麼需要處理邊界,逐個點進行計算並賦值:
C(0,j) += A(0,i) * B(i,j)
MNN的實現
void Matrix::multi(Tensor* C, const Tensor* A, const Tensor* B) {
MNN_ASSERT(NULL != C);
MNN_ASSERT(NULL != B);
MNN_ASSERT(NULL != A);
MNN_ASSERT(2 == C->dimensions());
MNN_ASSERT(2 == B->dimensions());
MNN_ASSERT(2 == A->dimensions());
const auto a = A->host<float>();
const auto b = B->host<float>();
auto c = C->host<float>();
const int h = A->length(0);
const int k = A->length(1);
const int w = B->length(1);
const int aw = A->stride(0);
const int bw = B->stride(0);
const int cw = C->stride(0);
MNN_ASSERT(k == B->length(0));
int y = 0;
for (; y < h; ++y) {
int x = 0;
const auto aLine = a + y * aw;
auto cLine = c + y * cw;
#ifdef MNN_USE_NEON
// firstly, compute 16 together
for (; x <= w - 16; x += 16) {
auto bColumn = b + x;
float32x4_t sum0 = vdupq_n_f32(0.0);
float32x4_t sum1 = vdupq_n_f32(0.0);
float32x4_t sum2 = vdupq_n_f32(0.0);
float32x4_t sum3 = vdupq_n_f32(0.0);
for (int i = 0; i < k; ++i) {
const auto bLine = bColumn + i * bw;
float32x4_t a0 = vdupq_n_f32(aLine[i]);
float32x4_t b0 = vld1q_f32(bLine);
float32x4_t b1 = vld1q_f32(bLine + 4);
float32x4_t b2 = vld1q_f32(bLine + 8);
float32x4_t b3 = vld1q_f32(bLine + 12);
sum0 = vmlaq_f32(sum0, a0, b0);
sum1 = vmlaq_f32(sum1, a0, b1);
sum2 = vmlaq_f32(sum2, a0, b2);
sum3 = vmlaq_f32(sum3, a0, b3);
}
vst1q_f32(cLine + x, sum0);
vst1q_f32(cLine + x + 4, sum1);
vst1q_f32(cLine + x + 8, sum2);
vst1q_f32(cLine + x + 12, sum3);
}
// secondly, compute 4 together
for (; x <= w - 4; x += 4) {
auto bColumn = b + x;
float32x4_t sum = vdupq_n_f32(0.0);
for (int i = 0; i < k; ++i) {
const auto bLine = bColumn + i * bw;
float32x4_t a4 = vdupq_n_f32(aLine[i]);
float32x4_t b4 = vld1q_f32(bLine);
sum = vmlaq_f32(sum, a4, b4);
}
vst1q_f32(cLine + x, sum);
}
#endif
for (; x < w; ++x) {
auto bColumn = b + x;
float sum = 0.0f;
for (int i = 0; i < k; ++i) {
sum += aLine[i] * bColumn[i * bw];
}
cLine[x] = sum;
}
}
}
關鍵部分是MNN_USE_NEON宏包裹的部分,具體的思路,對輸出矩陣C進行循環,因爲是行主序(每一行連續存儲),所以按行來進行計算,只不過它這裏,先按16循環,可以利用流水線,提升效率,然後對於小於16的部分,先4個一組處理,對於小於4的邊界部分,逐點處理。
評論
MNN的矩陣乘法實現可以說是標準的sgemm的一種簡單加速版本,但是有兩個點需要進一步考慮,一是沒有做pack,對於大矩陣,這種直接的vld1q_f32
可能會導致大量的cache miss,二是既然已經按行分別處理,且每一行的寫入過程相互獨立,所以可以考慮增加多線程來提高行間效率,可以使用Openmp或者自己起兩個Thread來進行並行處理。