This article walks through the computation of a quantized matrix multiplication. The product is:
$$C = A \times B$$
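Here $A$ is uint8, $B$ is int8, and $C$ is uint8. The computation is u8s8->s32: the uint8 × int8 products are accumulated as int32 (s32). The s32 result has to be dequantized to fp32 and then quantized again before it becomes uint8. In affine quantization, the zero_point of the s32 type is 0.
Matrix $A$ uses per-tensor quantization and matrix $B$ uses per_channel quantization, with one scale and one zero point per column: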
$$A_i^{int8} = \frac{A_i^{fp32}}{scale_{A}} + zero_A$$
$$B_i^{int8} = \frac{B_i^{fp32}}{scale_{B_{col_j}}} + zero_{B_{col_j}}$$
Here $A_i$ is the $i$-th element of one row of $A$ and $B_i$ is the $i$-th element of column $j$ of $B$, with $i = 0, \dots, k-1$. The superscript $int8$ marks the quantized 8-bit value (uint8 for $A$, int8 for $B$).
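The two quantization formulas can be sketched in NumPy as follows. This is a minimal illustration, not the implementation of any particular library: the function names, the sample matrices, and the scale and zero-point values are hypothetical, and the rounding and clamping steps are assumptions that the formulas above leave out.

```python
import numpy as np

def quantize_per_tensor(x_fp32, scale, zero_point, dtype=np.uint8):
    """q = round(x / scale) + zero_point, one scale for the whole tensor."""
    info = np.iinfo(dtype)
    q = np.rint(x_fp32 / scale) + zero_point   # rounding is an assumption
    return np.clip(q, info.min, info.max).astype(dtype)

def quantize_per_channel(x_fp32, scales, zero_points, dtype=np.int8):
    """Same formula, but column j uses scales[j] and zero_points[j]."""
    info = np.iinfo(dtype)
    q = np.rint(x_fp32 / scales[np.newaxis, :]) + zero_points[np.newaxis, :]
    return np.clip(q, info.min, info.max).astype(dtype)

# Hypothetical data, chosen so that every value quantizes exactly.
A = np.array([[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]], dtype=np.float32)
B = np.array([[0.25, -1.0], [0.5, 2.0], [-0.75, 3.0]], dtype=np.float32)

A_q = quantize_per_tensor(A, scale=0.5, zero_point=3)
B_q = quantize_per_channel(B, scales=np.array([0.25, 1.0]),
                           zero_points=np.array([-2, 1]))
print(A_q)
print(B_q)
```

With real data the rounding step introduces a small quantization error, so the identities derived below hold only approximately for the original fp32 values.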
Multiplying one pair of quantized elements and expanding gives:
$$
\begin{aligned}
A_i^{int8} \times B_i^{int8}
&= (\frac{A_i^{fp32}}{scale_{A}} + zero_A) \times (\frac{B_i^{fp32}}{scale_{B_{col_j}}} + zero_{B_{col_j}})\\
&= \frac{A_i^{fp32} \times B_i^{fp32}}{scale_{A} \times scale_{B_{col_j}}} + zero_A \times B_i^{int8} + \\
&\quad A_i^{int8} \times zero_{B_{col_j}} - zero_A \times zero_{B_{col_j}}
\end{aligned}
$$
The second line rewrites the cross terms with $\frac{A_i^{fp32}}{scale_A} = A_i^{int8} - zero_A$ and $\frac{B_i^{fp32}}{scale_{B_{col_j}}} = B_i^{int8} - zero_{B_{col_j}}$. The factor $\frac{1}{scale_A \times scale_{B_{col_j}}}$ stays attached to the fp32 product and is undone by the dequantize step.
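A quick numeric check of this expansion, with the scale factor kept on the fp32 product, may help. All values below are hypothetical and chosen so that quantization is exact, which keeps rounding error out of the comparison.

```python
# Hypothetical parameters for one element of A and one element of column j of B.
scale_a, zero_a = 0.5, 3
scale_b, zero_b = 0.25, -2
a_fp32, b_fp32 = 2.0, 1.5

# Quantized values according to the affine formulas (no rounding needed here).
a_q = a_fp32 / scale_a + zero_a
b_q = b_fp32 / scale_b + zero_b

# Left-hand side: product of the quantized values.
lhs = a_q * b_q

# Right-hand side: scaled fp32 product plus the three zero-point terms.
rhs = ((a_fp32 * b_fp32) / (scale_a * scale_b)
       + zero_a * b_q
       + a_q * zero_b
       - zero_a * zero_b)

print(lhs, rhs)
```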
Summing over the reduction index $i$:
$$
\begin{aligned}
\sum _{i=0}^{k-1}{A_i^{int8}\times B_i^{int8} }
&=\frac{1}{scale_{A} \times scale_{B_{col_j}}}\sum _{i=0}^{k-1}{A_i^{fp32}\times B_i^{fp32} } + zero_A\times \sum _{i=0}^{k-1}{B_i^{int8}} + \\
&\quad zero_{B_{col_j}}\times \sum _{i=0}^{k-1}{A_i^{int8}} - zero_A \times \sum _{i=0}^{k-1} zero_{B_{col_j}}\\
&=\frac{1}{scale_{A} \times scale_{B_{col_j}}}\sum _{i=0}^{k-1}{A_i^{fp32}\times B_i^{fp32} } + zero_A \times(\sum _{i=0}^{k-1}{B_i^{int8}} - \sum _{i=0}^{k-1} zero_{B_{col_j}} ) + \\
&\quad zero_{B_{col_j}}\times \sum _{i=0}^{k-1}{A_i^{int8}}
\end{aligned}
$$
Therefore:
$$
\begin{aligned}
\sum _{i=0}^{k-1}{A_i^{fp32}\times B_i^{fp32} }
&= scale_{A} \times scale_{B_{col_j}} \times \Big( \sum _{i=0}^{k-1}{A_i^{int8}\times B_i^{int8} } -
zero_{B_{col_j}}\times \sum _{i=0}^{k-1}{A_i^{int8}} - \\
&\quad zero_A \times(\sum _{i=0}^{k-1}{B_i^{int8}} - \sum _{i=0}^{k-1} zero_{B_{col_j}} ) \Big)
\end{aligned}
$$
Here $(\sum _{i=0}^{k-1}{B_i^{int8}} - \sum _{i=0}^{k-1} zero_{B_{col_j}})$ is the col_offset of column $j$ of $B$, and $\sum _{i=0}^{k-1}{A_i^{int8}}$ is the row_offset of the corresponding row of $A$. Because $zero_{B_{col_j}}$ does not depend on $i$, the second sum in col_offset is simply $k \times zero_{B_{col_j}}$.
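The whole flow (u8s8->s32, offset correction, dequantize, requantize) can be sketched in NumPy as below. This is an illustrative sketch under stated assumptions rather than the code of any specific library: the matrices, scales, and zero points are hypothetical, and so are the output parameters `scale_c` and `zero_c`. The dequantize step multiplies by scale_A × scale_B for each column, which is the factor the affine quantization formulas introduce.

```python
import numpy as np

# Hypothetical quantized inputs and parameters (not taken from the article).
A_q = np.array([[3, 4, 5], [6, 7, 8]], dtype=np.uint8)      # M x K, per-tensor
B_q = np.array([[-1, 0], [0, 3], [-5, 4]], dtype=np.int8)   # K x N, per-channel
scale_a, zero_a = 0.5, 3
scale_b = np.array([0.25, 1.0])
zero_b = np.array([-2, 1], dtype=np.int32)
K = A_q.shape[1]

# u8s8 -> s32: widen to int32 before multiplying to avoid overflow.
A_i32 = A_q.astype(np.int32)
B_i32 = B_q.astype(np.int32)
acc = A_i32 @ B_i32

# row_offset: one value per row of A. col_offset: one value per column of B.
row_offset = A_i32.sum(axis=1)                  # shape (M,)
col_offset = B_i32.sum(axis=0) - K * zero_b     # shape (N,)

# Subtract the zero-point terms, then dequantize column by column.
corrected = (acc
             - row_offset[:, None] * zero_b[None, :]
             - zero_a * col_offset[None, :])
C_fp32 = corrected * (scale_a * scale_b)[None, :]

# Reference path: dequantize A and B first, then multiply in floating point.
A_fp32 = (A_q.astype(np.float64) - zero_a) * scale_a
B_fp32 = (B_q.astype(np.float64) - zero_b) * scale_b
assert np.allclose(C_fp32, A_fp32 @ B_fp32)

# Requantize the fp32 result to uint8 (hypothetical output parameters).
scale_c, zero_c = 0.25, 10
C_q = np.clip(np.rint(C_fp32 / scale_c) + zero_c, 0, 255).astype(np.uint8)
print(C_fp32)
print(C_q)
```

Note that col_offset depends only on $B$ and its zero points, so for constant weights it can be computed once ahead of time, while row_offset has to be computed for each new $A$.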