MNN 中的矩陣乘法

背景

之前也寫過sgemm,然後就想看看MNN是如何實現的,有沒有什麼可以借鑑的地方,看完之後發現MNN的實現也是簡單的按行數據並行處理,記錄一下。

矩陣乘法

矩陣乘法的目的是完成一個計算:C = A * B,其中A是h * k, B是k * w,所以C是h * w。
在這裏插入圖片描述
常用的方式是分行處理,對於C的第一行,可以按如下方式處理:

C(0,j) += A(0,i)*B(i,j)

對於行主序矩陣,每一行的數據是連續存儲的,我們自然可以考慮使用SIMD指令,一次處理4個(假設是Float32)數據的相乘:

float32x4_t a0   = vdupq_n_f32(aLine[i]);
float32x4_t b0   = vld1q_f32(bLine);
float32x4_t sum0 = vdupq_n_f32(0.0);
sum0             = vmlaq_f32(sum0, a0, b0);
vst1q_f32(cLine, sum0);

需要注意的一點是,如果w不能被4整除,那麼需要處理邊界,逐個點進行計算並賦值:

C(0,j) += A(0,i) * B(i,j)

MNN的實現

void Matrix::multi(Tensor* C, const Tensor* A, const Tensor* B) {
    MNN_ASSERT(NULL != C);
    MNN_ASSERT(NULL != B);
    MNN_ASSERT(NULL != A);

    MNN_ASSERT(2 == C->dimensions());
    MNN_ASSERT(2 == B->dimensions());
    MNN_ASSERT(2 == A->dimensions());

    const auto a = A->host<float>();
    const auto b = B->host<float>();
    auto c       = C->host<float>();

    const int h = A->length(0);
    const int k = A->length(1);
    const int w = B->length(1);

    const int aw = A->stride(0);
    const int bw = B->stride(0);
    const int cw = C->stride(0);

    MNN_ASSERT(k == B->length(0));

    int y = 0;
    for (; y < h; ++y) {
        int x            = 0;
        const auto aLine = a + y * aw;
        auto cLine       = c + y * cw;
#ifdef MNN_USE_NEON
        // firstly, compute 16 together
        for (; x <= w - 16; x += 16) {
            auto bColumn     = b + x;
            float32x4_t sum0 = vdupq_n_f32(0.0);
            float32x4_t sum1 = vdupq_n_f32(0.0);
            float32x4_t sum2 = vdupq_n_f32(0.0);
            float32x4_t sum3 = vdupq_n_f32(0.0);
            for (int i = 0; i < k; ++i) {
                const auto bLine = bColumn + i * bw;
                float32x4_t a0   = vdupq_n_f32(aLine[i]);
                float32x4_t b0   = vld1q_f32(bLine);
                float32x4_t b1   = vld1q_f32(bLine + 4);
                float32x4_t b2   = vld1q_f32(bLine + 8);
                float32x4_t b3   = vld1q_f32(bLine + 12);
                sum0             = vmlaq_f32(sum0, a0, b0);
                sum1             = vmlaq_f32(sum1, a0, b1);
                sum2             = vmlaq_f32(sum2, a0, b2);
                sum3             = vmlaq_f32(sum3, a0, b3);
            }
            vst1q_f32(cLine + x, sum0);
            vst1q_f32(cLine + x + 4, sum1);
            vst1q_f32(cLine + x + 8, sum2);
            vst1q_f32(cLine + x + 12, sum3);
        }
        // secondly, compute 4 together
        for (; x <= w - 4; x += 4) {
            auto bColumn    = b + x;
            float32x4_t sum = vdupq_n_f32(0.0);
            for (int i = 0; i < k; ++i) {
                const auto bLine = bColumn + i * bw;
                float32x4_t a4   = vdupq_n_f32(aLine[i]);
                float32x4_t b4   = vld1q_f32(bLine);
                sum              = vmlaq_f32(sum, a4, b4);
            }
            vst1q_f32(cLine + x, sum);
        }
#endif
        for (; x < w; ++x) {
            auto bColumn = b + x;
            float sum    = 0.0f;
            for (int i = 0; i < k; ++i) {
                sum += aLine[i] * bColumn[i * bw];
            }
            cLine[x] = sum;
        }
    }
}

關鍵部分是MNN_USE_NEON宏包裹的部分,具體的思路,對輸出矩陣C進行循環,因爲是行主序(每一行連續存儲),所以按行來進行計算,只不過它這裏,先按16循環,可以利用流水線,提升效率,然後對於小於16的部分,先4個一組處理,對於小於4的邊界部分,逐點處理。

評論

MNN的矩陣乘法實現可以說是標準的sgemm的一種簡單加速版本,但是有兩個點需要進一步考慮,一是沒有做pack,對於大矩陣,這種直接的vld1q_f32可能會導致大量的cache miss,二是既然已經按行分別處理,且每一行的寫入過程相互獨立,所以可以考慮增加多線程來提高行間效率,可以使用Openmp或者自己起兩個Thread來進行並行處理。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章