背景
之前也写过sgemm,然后就想看看MNN是如何实现的,有没有什么可以借鉴的地方,看完之后发现MNN的实现也是简单的按行数据并行处理,记录一下。
矩阵乘法
矩阵乘法的目的是完成一个计算:C = A * B
,其中A是h * k, B是k * w,所以C是h * w。
常用的方式是分行处理,对于C的第一行,可以按如下方式处理:
C(0,j) += A(0,i)*B(i,j)
对于行主序矩阵,每一行的数据是连续存储的,我们自然可以考虑使用SIMD指令,一次处理4个(假设是Float32)数据的相乘:
float32x4_t a0 = vdupq_n_f32(aLine[i]);
float32x4_t b0 = vld1q_f32(bLine);
float32x4_t sum0 = vdupq_n_f32(0.0);
sum0 = vmlaq_f32(sum0, a0, b0);
vst1q_f32(cLine, sum0);
需要注意的一点是,如果w不能被4整除,那么需要处理边界,逐个点进行计算并赋值:
C(0,j) += A(0,i) * B(i,j)
MNN的实现
void Matrix::multi(Tensor* C, const Tensor* A, const Tensor* B) {
MNN_ASSERT(NULL != C);
MNN_ASSERT(NULL != B);
MNN_ASSERT(NULL != A);
MNN_ASSERT(2 == C->dimensions());
MNN_ASSERT(2 == B->dimensions());
MNN_ASSERT(2 == A->dimensions());
const auto a = A->host<float>();
const auto b = B->host<float>();
auto c = C->host<float>();
const int h = A->length(0);
const int k = A->length(1);
const int w = B->length(1);
const int aw = A->stride(0);
const int bw = B->stride(0);
const int cw = C->stride(0);
MNN_ASSERT(k == B->length(0));
int y = 0;
for (; y < h; ++y) {
int x = 0;
const auto aLine = a + y * aw;
auto cLine = c + y * cw;
#ifdef MNN_USE_NEON
// firstly, compute 16 together
for (; x <= w - 16; x += 16) {
auto bColumn = b + x;
float32x4_t sum0 = vdupq_n_f32(0.0);
float32x4_t sum1 = vdupq_n_f32(0.0);
float32x4_t sum2 = vdupq_n_f32(0.0);
float32x4_t sum3 = vdupq_n_f32(0.0);
for (int i = 0; i < k; ++i) {
const auto bLine = bColumn + i * bw;
float32x4_t a0 = vdupq_n_f32(aLine[i]);
float32x4_t b0 = vld1q_f32(bLine);
float32x4_t b1 = vld1q_f32(bLine + 4);
float32x4_t b2 = vld1q_f32(bLine + 8);
float32x4_t b3 = vld1q_f32(bLine + 12);
sum0 = vmlaq_f32(sum0, a0, b0);
sum1 = vmlaq_f32(sum1, a0, b1);
sum2 = vmlaq_f32(sum2, a0, b2);
sum3 = vmlaq_f32(sum3, a0, b3);
}
vst1q_f32(cLine + x, sum0);
vst1q_f32(cLine + x + 4, sum1);
vst1q_f32(cLine + x + 8, sum2);
vst1q_f32(cLine + x + 12, sum3);
}
// secondly, compute 4 together
for (; x <= w - 4; x += 4) {
auto bColumn = b + x;
float32x4_t sum = vdupq_n_f32(0.0);
for (int i = 0; i < k; ++i) {
const auto bLine = bColumn + i * bw;
float32x4_t a4 = vdupq_n_f32(aLine[i]);
float32x4_t b4 = vld1q_f32(bLine);
sum = vmlaq_f32(sum, a4, b4);
}
vst1q_f32(cLine + x, sum);
}
#endif
for (; x < w; ++x) {
auto bColumn = b + x;
float sum = 0.0f;
for (int i = 0; i < k; ++i) {
sum += aLine[i] * bColumn[i * bw];
}
cLine[x] = sum;
}
}
}
关键部分是MNN_USE_NEON宏包裹的部分,具体的思路,对输出矩阵C进行循环,因为是行主序(每一行连续存储),所以按行来进行计算,只不过它这里,先按16循环,可以利用流水线,提升效率,然后对于小于16的部分,先4个一组处理,对于小于4的边界部分,逐点处理。
评论
MNN的矩阵乘法实现可以说是标准的sgemm的一种简单加速版本,但是有两个点需要进一步考虑,一是没有做pack,对于大矩阵,这种直接的vld1q_f32
可能会导致大量的cache miss,二是既然已经按行分别处理,且每一行的写入过程相互独立,所以可以考虑增加多线程来提高行间效率,可以使用Openmp或者自己起两个Thread来进行并行处理。