Ten Implementations of General Matrix Multiplication on x86

Preface

This post optimizes matrix multiplication (GEMM) on an Intel platform, mainly through memory-layout changes (for cache friendliness), SIMD (SSE), and multithreading. Matrices A, B, and C are of size M×K, K×N, and M×N respectively. All performance figures are averages over T runs with M=N=K=1024; the full code is linked at the end. The defining formula is given below; it can be read alongside the code.
$$C_{m,n}=\sum_{k=1}^{K}A_{m,k}\,B_{k,n}$$
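All timings below come from a harness along these lines; this is a minimal sketch of my own, not the post's actual driver (the helper name run_bench, std::chrono, and the all-ones test data are my choices):

#include <chrono>
#include <cstdio>
#include <vector>

// Hypothetical timing helper: averages T runs of one GEMM implementation.
void run_bench(void (*gemm)(const float*, const float*, float*, int, int, int),
               int M, int N, int K, int T)
{
	std::vector<float> A((size_t)M * K, 1.0f), B((size_t)K * N, 1.0f), C((size_t)M * N);
	auto t0 = std::chrono::steady_clock::now();
	for (int t = 0; t < T; t++)
		gemm(A.data(), B.data(), C.data(), M, N, K);
	auto t1 = std::chrono::steady_clock::now();
	printf("%.1f ms/run\n", std::chrono::duration<double, std::milli>(t1 - t0).count() / T);
}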

v0. Implementing the definition verbatim

This version follows the mathematical definition exactly, with no optimization at all; it is the starting point from which we optimize step by step. Note the function suffix MNK, which gives the nesting order of the three loops.

void gemm_v0_MNK(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n++)
		{
			for (int k = 0; k < K; k++)
			{
				C[m*N + n] += A[m*K + k] * B[k*N + n];
			}
		}
	}
	return;
}

On my i5-7400 this takes 4289 ms.

v1. Reordering the loops

In v0, the innermost loop walks B along a column, which causes frequent cache misses once B is reasonably wide. Reordering the loops lets us access B row by row instead.

void gemm_v1_MKN(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	for (int m = 0; m < M; m++)
	{
		for (int k = 0; k < K; k++)
		{
			float a = A[m*K + k];
			for (int n = 0; n < N; n++)
			{
				C[m*N + n] += a* B[k*N + n];
			}
		}
	}
	return;
}

After reordering, the innermost loop multiplies a single value of A by a row of B and accumulates the products into the corresponding row of C. Note that every position of C is now written back K times. This version takes 1790 ms.

v2. Transposing B

v1 suffers from those repeated write-backs of C. A different idea: transpose B, so that what used to be a column read of B becomes a row read. The loop order returns to MNK.

void transpose(const float *A, float *B, int M, int N)
{
	for (int n = 0; n < N; n++)
	{
		for (int m = 0; m < M; m++)
		{
			B[n*M + m] = A[N*m + n];
		}
	}
}

void gemm_v2_MNK_transposeB(const float *A, const float *B, float *C, int M, int N, int K)
{
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n++)
		{
			float sum = 0.0f;
			for (int k = 0; k < K; k++)
			{
				sum += A[m*K + k] * B[n*K + k];
			}
			C[m*N + n] = sum;
		}
	}
	return;
}

This version, transpose plus multiply-accumulate combined, comes to 1620 ms.

v3. Blocked matrix multiplication

The goal is still cache hit rate. We split the matrices into tiles small enough to fit in the L1 cache. The best tile size depends on the machine's caches and even on the problem size; see this blog post for a deeper treatment. In fact, when a problem can't be fully modeled, brute-forcing the parameter under the given constraints (here, the choice of tile size) is itself a common optimization technique; a sketch of such a sweep follows the implementation below.

// Multiply one BLOCKSIZE x BLOCKSIZE tile: C_tile += A_tile * B_tile.
inline void do_block(const float *A, const float *B, float *C, int K, int N, int BLOCKSIZE)
{
	
	for (int m = 0; m < BLOCKSIZE; m++)
	{
		for (int n = 0; n < BLOCKSIZE; n++)
		{
			float c = C[m*N + n];
			for (int k = 0; k < BLOCKSIZE; k++)
				c += A[m*K + k] * B[k*N + n];
			C[m*N + n] = c;
		}
	}
}


// Blocked matrix multiplication
void dgemm_block(const float *A, const float *B, float *C, int M, int N, int K)
{
	const int BLOCKSIZE = 64;
	memset(C, 0, M*N*sizeof(float));
	for (int m = 0; m < M; m += BLOCKSIZE)
	{
		for (int n = 0; n < N; n += BLOCKSIZE)
		{
			for (int k = 0; k < K; k += BLOCKSIZE)
			{
				do_block(A + m*K + k, B + k*N + n, C + m*N + n, K, N, BLOCKSIZE);
			}
		}
	}
	return;

}

This version comes in at 1633 ms. Everything so far has targeted memory access, hoping to raise the cache hit rate. The code above is still not very tight, though: indices are recomputed, loops aren't unrolled, and so on; later versions fill in those details. The tile kernel could also be inlined by hand rather than kept as a function, saving call overhead. (Also note this simple version assumes M, N, and K are multiples of BLOCKSIZE.)
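As a concrete example of that brute-force search, here is a sketch of my own (not from the original code). It reuses do_block from above and assumes a hypothetical variant of dgemm_block that takes the tile size as a parameter, with M, N, K divisible by every candidate size:

#include <chrono>
#include <cstdio>
#include <cstring>

// Hypothetical variant of dgemm_block with BLOCKSIZE as a parameter.
void dgemm_block_bs(const float *A, const float *B, float *C,
                    int M, int N, int K, int BLOCKSIZE)
{
	memset(C, 0, M*N*sizeof(float));
	for (int m = 0; m < M; m += BLOCKSIZE)
		for (int n = 0; n < N; n += BLOCKSIZE)
			for (int k = 0; k < K; k += BLOCKSIZE)
				do_block(A + m*K + k, B + k*N + n, C + m*N + n, K, N, BLOCKSIZE);
}

// Time each candidate tile size and report the fastest.
int best_blocksize(const float *A, const float *B, float *C, int M, int N, int K)
{
	const int candidates[] = { 16, 32, 64, 128, 256 };
	int best = candidates[0];
	double best_ms = 1e30;
	for (int bs : candidates)
	{
		auto t0 = std::chrono::steady_clock::now();
		dgemm_block_bs(A, B, C, M, N, K, bs);
		auto t1 = std::chrono::steady_clock::now();
		double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
		printf("BLOCKSIZE=%d: %.1f ms\n", bs, ms);
		if (ms < best_ms) { best_ms = ms; best = bs; }
	}
	return best;
}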

v4. A first pass with SSE

In v1's innermost loop, a single value of A multiplies an entire row of B, which maps quite directly onto vector code. See the code and comments.

void gemm_v1_MKN_SSE(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	int m, n, k;
	for (m = 0; m < M; m++)
	{
		for (k = 0; k < K; k++)
		{
			__m128 v4_a = _mm_set1_ps(*(A + m*K + k));// Am,k Am,k Am,k Am,k
			for (n = 0; n < N - 3; n += 4)
			{
				__m128 v4_b = _mm_loadu_ps(B + k*N + n); // Bk,n Bk,n+1 Bk,n+2 Bk,n+3
				__m128 v4_c = _mm_loadu_ps(C + m*N + n);
				_mm_storeu_ps(C + m*N + n, _mm_add_ps(v4_c, _mm_mul_ps(v4_a, v4_b)));
			}
			for (; n < N; n++)
			{
				C[m*N + n] += A[m*K + k] * B[k*N + n];
			}
		}
	}
	return;
}

Performance is 794 ms, roughly double v1's 1790 ms. Even a naive application of SSE delivers a respectable gain.

v5. Loop unrolling

To give the compiler more room for software pipelining, cut loop-condition checks, and reduce write-backs of C, we extend v4 so that each pass of the inner loop consumes four rows of B at once (the k loop is unrolled by 4).

void gemm_v1_MKN_SSE_UNROLL(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	int m, n, k;
	for (m = 0; m < M; m++)
	{
		for (k = 0; k < K - 3; k += 4)
		{
			__m128 v4_a0 = _mm_set1_ps(*(A + m*K + k));
			__m128 v4_a1 = _mm_set1_ps(*(A + m*K + k + 1));
			__m128 v4_a2 = _mm_set1_ps(*(A + m*K + k + 2));
			__m128 v4_a3 = _mm_set1_ps(*(A + m*K + k + 3));
			for (n = 0; n < N - 3; n += 4)
			{
				__m128 v4_b0 = _mm_loadu_ps(B + k*N + n);
				__m128 v4_b1 = _mm_loadu_ps(B + k*N + n + N);
				__m128 v4_b2 = _mm_loadu_ps(B + k*N + n + 2 * N);
				__m128 v4_b3 = _mm_loadu_ps(B + k*N + n + 3 * N);

				__m128 v4_c = _mm_loadu_ps(C + m*N + n);
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a0, v4_b0));
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a1, v4_b1));
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a2, v4_b2));
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a3, v4_b3));
				_mm_storeu_ps(C + m*N + n, v4_c);
			}
			for (; n < N; n++)
			{
				C[m*N + n] += A[m*K + k] * B[k*N + n];
				C[m*N + n] += A[m*K + k + 1] * B[(k + 1)*N + n];
				C[m*N + n] += A[m*K + k + 2] * B[(k + 2)*N + n];
				C[m*N + n] += A[m*K + k + 3] * B[(k + 3)*N + n];
			}
		}
		for (; k < K; k++)
		{
			__m128 v4_a0 = _mm_set1_ps(*(A + m*K + k));

			for (n = 0; n < N - 3; n += 4)
			{
				__m128 v4_b = _mm_loadu_ps(B + k*N + n);
				__m128 v4_c = _mm_loadu_ps(C + m*N + n);
				_mm_storeu_ps(C + m*N + n, _mm_add_ps(v4_c, _mm_mul_ps(v4_a0, v4_b)));
			}

			float a = A[m*K + k];
			for (; n < N; n++)
			{
				C[m*N + n] += a* B[k*N + n];
			}
		}
	}
	return;
}

Take care to handle the tail when K is not a multiple of 4. The number of **_mm_storeu_ps** stores drops to a quarter of what it was, and this version runs in **463 ms**, another 2x improvement. Onward.

v6. SSE + unrolling on top of v2 (transposed B)

v5 already brought a solid speedup, but as analyzed earlier, that dataflow writes C back many times. Let's see what SSE does for the transposed version.

void gemm_v2_MNK_SSE_UNROLL(const float *A, const float *B, float *C, int M, int N, int K)
{
	int k = 0, n = 0;
	__m128 v4_1_ps = _mm_set1_ps(1.0f);
	__m128 v4_sum_tmp_ps, v4_sumv_tmp_ps;
	for (int m = 0; m < M; m++)
	{
		for (n = 0; n < N - 3; n += 4)
		{
			float sum0, sum1, sum2, sum3;
			__m128 v4_sum0 = _mm_setzero_ps();
			__m128 v4_sum1 = _mm_setzero_ps();
			__m128 v4_sum2 = _mm_setzero_ps();
			__m128 v4_sum3 = _mm_setzero_ps();

			sum0 = sum1 = sum2 = sum3 = 0.0f;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 a = _mm_loadu_ps(A + m*K + k);

				__m128 b0 = _mm_loadu_ps(B + n*K + k);
				__m128 b1 = _mm_loadu_ps(B + n*K + k + K);
				__m128 b2 = _mm_loadu_ps(B + n*K + k + 2 * K);
				__m128 b3 = _mm_loadu_ps(B + n*K + k + 3 * K);

				v4_sum0 = _mm_add_ps(v4_sum0, _mm_mul_ps(a, b0));
				v4_sum1 = _mm_add_ps(v4_sum1, _mm_mul_ps(a, b1));
				v4_sum2 = _mm_add_ps(v4_sum2, _mm_mul_ps(a, b2));
				v4_sum3 = _mm_add_ps(v4_sum3, _mm_mul_ps(a, b3));
			}
			for (; k < K; k++)
			{
				sum0 += A[m*K + k] * B[n*K + k];
				sum1 += A[m*K + k] * B[n*K + k + K];
				sum2 += A[m*K + k] * B[n*K + k + 2 * K];
				sum3 += A[m*K + k] * B[n*K + k + 3 * K];
			}
			v4_sum_tmp_ps = _mm_setr_ps(sum0, sum1, sum2, sum3);

			//v4_sumv_tmp_ps.m128_f32[0] = v4_sum0.m128_f32[0] + v4_sum0.m128_f32[1] + v4_sum0.m128_f32[2] + v4_sum0.m128_f32[3];
			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum0, v4_1_ps, 0xF1);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum1, v4_1_ps, 0xF2);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum2, v4_1_ps, 0xF4);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum3, v4_1_ps, 0xF8);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			_mm_storeu_ps(C + m*N + n, v4_sum_tmp_ps);
		}//end for n=0~N-3
		for (; n < N; n++)
		{
			float sum0;
			__m128 v4_sum0 = _mm_setzero_ps();
			sum0 = 0.0f;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 a = _mm_loadu_ps(A + m*K + k);
				__m128 b0 = _mm_loadu_ps(B + n*K + k);
				v4_sum0 = _mm_add_ps(v4_sum0, _mm_mul_ps(a, b0));
			}
			for (; k < K; k++)
			{
				sum0 += A[m*K + k] * B[n*K + k];
			}
			C[m*N + n] = sum0 + v4_sum0.m128_f32[0] + v4_sum0.m128_f32[1] + v4_sum0.m128_f32[2] + v4_sum0.m128_f32[3]; // note: .m128_f32 is MSVC-specific
		}//end for n=N-3~N
	}// end for m
	return;
}

Performance: 451 ms. One awkward spot in this code: after the loop, the four lanes of each accumulator vector have to be summed horizontally. That operation is inefficient (even more so on DSP platforms), and most SIMD code tries to avoid it. Here it is done with the _mm_dp_ps dot-product instruction (SSE4.1): multiply the vector by all-ones and accumulate the result into the lane selected by the mask.
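If SSE3 is available, the same horizontal reduction can also be written with _mm_hadd_ps; a small sketch of my own, as an alternative to the _mm_dp_ps trick above:

#include <pmmintrin.h> // SSE3

// Sum the four lanes of v into a scalar: two hadd steps fold
// (a0,a1,a2,a3) -> (a0+a1, a2+a3, ...) -> (a0+a1+a2+a3, ...).
static inline float hsum_ps(__m128 v)
{
	v = _mm_hadd_ps(v, v);
	v = _mm_hadd_ps(v, v);
	return _mm_cvtss_f32(v);
}

Either way the reduction happens only once per output element, but a layout that avoids it entirely (like v7 below) is better still.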

v7. Transposing B in 1x4 blocks

Before this optimization, let's first redefine the transpose, taking a 1x4 vector as the basic unit.
For example, in the sketch below each digit stands for one 1x4 vector. Say the original matrix is

1 5 
2 6
3 7
4 8

then after the transpose it becomes

1 2 3 4 
5 6 7 8

but the four elements inside each numbered vector keep their original order. With this layout, multiplying one value of A by one row of the repacked B produces four results at once; compared with v6, the horizontal adds are gone entirely.
Another way to see it: the multiplication consumes four columns of B at a time (those four columns form a 1x4 vector that multiplies a single value of A), so we simply transpose in units of those four columns. Obviously this requires N to be a multiple of 4.

// 1x4-vector transpose; note how the matrix shape changes:
// M*N -> (N/4) * (4M)
void transpose_vec4(const float *A, float *B, int M, int N)
{
	int m, n;
	for (m = 0; m < M; m++)
	{
		for (n = 0; n < N; n += 4)
		{
			__m128 a = _mm_loadu_ps(A + m*N + n);
			_mm_storeu_ps(B + n*M + (m << 2), a);
		}
	}
}

// GEMM over the 1x4-block-transposed B
void gemm_v2_MNK_SSE_UNROLL_TRANSPOSEV4(const float *A, const float *B, float *C, int M, int N, int K)
{
	assert(0 == N % 4);
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum = _mm_set1_ps(0.0f);
			const float* pA = A + m*K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 v4_a0 = _mm_load1_ps(pA);
				__m128 v4_a1 = _mm_load1_ps(pA + 1);
				__m128 v4_a2 = _mm_load1_ps(pA + 2);
				__m128 v4_a3 = _mm_load1_ps(pA + 3);

				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);

				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				v4_c = _mm_mul_ps(v4_a1, v4_b1);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				v4_c = _mm_mul_ps(v4_a2, v4_b2);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				v4_c = _mm_mul_ps(v4_a3, v4_b3);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 4;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 1;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum);
		}
	}
	return;
}

We also unrolled over k, consuming four rows of the repacked B per iteration. Final performance: 449 ms.

v8. Multithreading with OpenMP

Simply add OpenMP on top of v7.

// 1x4-block-transposed GEMM + OpenMP
void gemm_v2_MNK_SSE_UNROLL_TRANSPOSEV4_OMP(const float *A, const float *B, float *C, int M, int N, int K)
{
	assert(0 == N % 4);
#ifdef _OPENMP 
	omp_set_num_threads(4);
#pragma omp parallel for 
#endif
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum = _mm_set1_ps(0.0f);
			const float* pA = A + m*K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{

				__m128 v4_a0 = _mm_load1_ps(pA);
				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				__m128 v4_a1 = _mm_load1_ps(pA + 1);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				v4_c = _mm_mul_ps(v4_a1, v4_b1);
				v4_sum = _mm_add_ps(v4_sum, v4_c);


				__m128 v4_a2 = _mm_load1_ps(pA + 2);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);

				v4_c = _mm_mul_ps(v4_a2, v4_b2);
				v4_sum = _mm_add_ps(v4_sum, v4_c);
				
				__m128 v4_a3 = _mm_load1_ps(pA + 3);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);
				v4_c = _mm_mul_ps(v4_a3, v4_b3);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 4;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 1;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum);
		}
	}
	return;
}

Performance: 108 ms.
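One caveat: the thread count above is hardcoded to 4, matching my 4-core i5-7400. A more portable sketch (my own, not from the original code) would size the team to the machine:

#include <omp.h>

// Size the OpenMP team to the machine instead of hardcoding 4 threads.
void set_gemm_threads(void)
{
	omp_set_num_threads(omp_get_num_procs());
}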

v9. More unrolling

Let's test whether still more unrolling helps. Compared with v8, this version additionally unrolls m by 2.

void gemm_v2_MNK_SSE_UNROLL2_TRANSPOSEV4_OMP(const float *A, const float *B, float *C, int M, int N, int K)
{
#define CAL_ROWX(x) \
	v4_c = _mm_mul_ps(v4_a0, v4_b0); \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
	v4_c = _mm_mul_ps(v4_a1, v4_b1);	 \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
	v4_c = _mm_mul_ps(v4_a2, v4_b2);	 \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
	v4_c = _mm_mul_ps(v4_a3, v4_b3);	 \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c);

	assert(0 == N % 4);
	int m = 0;
#ifdef _OPENMP 
	omp_set_num_threads(4);
#pragma omp parallel for lastprivate(m)
#endif
	for (m = 0; m < M - 1; m += 2)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum0 = _mm_set1_ps(0.0f);
			__m128 v4_sum1 = v4_sum0;
			const float* pA0 = A + m*K;
			const float* pA1 = A + m*K + K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 v4_c;
				// row0
				__m128 v4_a0 = _mm_load1_ps(pA0);
				__m128 v4_a1 = _mm_load1_ps(pA0 + 1);
				__m128 v4_a2 = _mm_load1_ps(pA0 + 2);
				__m128 v4_a3 = _mm_load1_ps(pA0 + 3);


				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);

				CAL_ROWX(0)

				// row1
				v4_a0 = _mm_load1_ps(pA1);
				v4_a1 = _mm_load1_ps(pA1 + 1);
				v4_a2 = _mm_load1_ps(pA1 + 2);
				v4_a3 = _mm_load1_ps(pA1 + 3);

				CAL_ROWX(1)

				pA0 += 4;
				pA1 += 4;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA0);
				__m128 v4_a1 = _mm_load1_ps(pA1);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				// row0
				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum0 = _mm_add_ps(v4_sum0, v4_c);

				// row1
				v4_c = _mm_mul_ps(v4_a1, v4_b0);
				v4_sum1 = _mm_add_ps(v4_sum1, v4_c);

				pA0++;
				pA1++;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum0);
			_mm_storeu_ps(C + m*N + N + n, v4_sum1);
		}
	}

	// remainder: the last row when M is odd (m was rounded down to an even value)
	for (; m < M; m++)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum0 = _mm_set1_ps(0.0f);
			__m128 v4_c;
			const float* pA0 = A + m*K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{
				// row0
				__m128 v4_a0 = _mm_load1_ps(pA0);
				__m128 v4_a1 = _mm_load1_ps(pA0 + 1);
				__m128 v4_a2 = _mm_load1_ps(pA0 + 2);
				__m128 v4_a3 = _mm_load1_ps(pA0 + 3);


				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);

				CAL_ROWX(0)

				pA0 += 4;
				pB += 16;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA0);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				// row0
				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum0 = _mm_add_ps(v4_sum0, v4_c);

				pA0++;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum0);
		}
	}

	return;
}

Note the OpenMP private declaration for m (lastprivate). Unfortunately, performance got worse: 150 ms. Most of the time, over-unrolling hurts because of register spilling, which you can confirm in the disassembly.
So how do you know how much unrolling is right? Since the unrolled bodies are mostly the same logic over different data, you can parameterize the variable names and capture the repetition in macros like CAL_ROWX(x) above; even the declarations can be macros whose inputs are the row/column indices.
If the macros are factored well, unrolling N times becomes just N macro calls. On top of that you can write a program that generates implementations with different unrolling combinations and brute-forces which one is fastest (though the winner usually isn't much faster than the runner-up).
The downsides: readability drops and single-step debugging gets harder (you can't step into a macro body). That's why this v9 stops short of full parameterization; the goal here is to explain the principle, as in the sketch below.

v10. OpenBLAS GEMM

There is also an eleventh version: simply call OpenBLAS's GEMM. It turns out to be another 2-3x faster than our best version, around 40-50 ms. We never did assembly-level tuning, and the blocked multiplication was never merged into our final version.
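For reference, the call maps onto the standard CBLAS interface roughly like this (single precision, row major, alpha=1, beta=0; the wrapper name is my own):

#include <cblas.h> // OpenBLAS ships this CBLAS header

// C = 1.0f * A * B + 0.0f * C, row-major, same layout as the rest of this post.
void gemm_openblas(const float *A, const float *B, float *C, int M, int N, int K)
{
	cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
	            M, N, K,
	            1.0f, A, K,   // lda = K
	                  B, N,   // ldb = N
	            0.0f, C, N);  // ldc = N
}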

Summary

This post applied a series of fairly elementary optimizations to matrix multiplication. A follow-up may build finer-grained optimizations on top of it. To study production-grade optimized code, look at the open-source deep-learning inference libraries from the major vendors.


Code

Full code
Note: edit the select array in main.cpp to choose which GEMM versions to test.

My understanding is limited; corrections are welcome.
