Kronecker Product及pytorch實現

Kronecker Product及pytorch實現

原始文檔:https://www.yuque.com/lart/idh721/gb2h93

計算過程

[a11a12a13a21a22a23],B=[b11b12b21b22b31b32],AB=[a11b11a11b12a12b11a12b12a13b11a13b12a11b21a11b22a12b21a12b22a13b21a13b22a11b31a11b32a12b31a12b32a13b31a13b32a21b11a21b12a22b11a22b12a23b11a23b12a21b21a21b22a22b21a22b22a23b21a23b22a21b31a21b32a22b31a22b32a23b31a23b32] \begin{bmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \end{bmatrix}, B =\begin{bmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \\ b_{31} & b_{32} \end{bmatrix}, \\ A \otimes B=\begin{bmatrix} a_{11}b_{11} & a_{11}b_{12} & a_{12}b_{11} & a_{12}b_{12} & a_{13}b_{11} & a_{13}b_{12} \\ a_{11}b_{21} & a_{11}b_{22} & a_{12}b_{21} & a_{12}b_{22} & a_{13}b_{21} & a_{13}b_{22} \\ a_{11}b_{31} & a_{11}b_{32} & a_{12}b_{31} & a_{12}b_{32} & a_{13}b_{31} & a_{13}b_{32} \\ a_{21}b_{11} & a_{21}b_{12} & a_{22}b_{11} & a_{22}b_{12} & a_{23}b_{11} & a_{23}b_{12} \\ a_{21}b_{21} & a_{21}b_{22} & a_{22}b_{21} & a_{22}b_{22} & a_{23}b_{21} & a_{23}b_{22} \\ a_{21}b_{31} & a_{21}b_{32} & a_{22}b_{31} & a_{22}b_{32} & a_{23}b_{31} & a_{23}b_{32} \end{bmatrix}

PyTorch實現

起因是看到了這篇文章:https://zhuanlan.zhihu.com/p/79295551介紹了一種新穎的卷積方式,其中使用了kronecker product方法來實現。這種計算理解很容易,但是實現起來該如何編程,這是一個值得思考的問題。文章結尾作者推薦的代碼中給出了一種實現https://github.com/d-li14/dgconv.pytorch/blob/master/dgconv.py#L26

def kronecker_product(mat1, mat2):
    out_mat = torch.ger(mat1.view(-1), mat2.view(-1))
    # 這裏的(mat1.size() + mat2.size())表示的是將兩個list拼接起來
    out_mat = out_mat.reshape(*(mat1.size() + mat2.size())).permute([0, 2, 1, 3])
    out_mat = out_mat.reshape(mat1.size(0) * mat2.size(0), mat1.size(1) * mat2.size(1))
    return out_mat 


這裏應該參考的是這裏的方法https://discuss.pytorch.org/t/kronecker-product/3919/7?u=i-love-u。但是該帖子後面給出了一種更簡單的方法https://discuss.pytorch.org/t/kronecker-product/3919/10

def kronecker(A, B):
    AB = torch.einsum("ab,cd->acbd", A, B)
    AB = AB.view(A.size(0)*B.size(0), A.size(1)*B.size(1))
    return AB

二者實際上是一致的,也就是說這裏的out_mat = torch.ger(mat1.view(-1), mat2.view(-1)), ``out_mat = out_mat.reshape(*(mat1.size() + mat2.size())).permute([0, 2, 1, 3]),與這裏的enisum AB = torch.einsum("ab,cd->acbd", A, B) 表示的是一樣的行爲。

在前者中,假設 A=mat1B=mat2 ,二者分別爲axb和cxd大小的矩陣。對於二者先通過矢量化後,利用外積操作計算出了各個元素之間的乘積構成的矩陣,大小爲abxcd,再將結果調整爲axbxcxd大小的形狀,利用permute操作調整後,變成了axcxbxd大小的形狀。這也正是符合einsum中的維度索引的調整 'ab, cd->acbd' 。殊途同歸。

簡單的解釋。我們最終的目標是得到這樣一個結果:
image.png

注意!這裏有下角的2x2方陣中的b的下標有誤,注意!

這裏我標註了A和B對應的下標的變化範圍。由於我們利用PyTorch實現這些處理,那我們必定是要使用對應的矩陣運算和形狀變換的。從這裏的圖可以看出來,若是對於這裏的結果直接使用 view 操作,那麼可以得到這樣的結果:
image.png

注意!這裏和前圖相同對應的位置上下標有誤,注意!

這實際上也就是索引爲 (i,m,j,n) 的2D矩陣。 也就是更直接的,對於 torch.einsum("ab,cd->acbd", A, B) 的結果的表示。所以也就出現了前面的兩種形式的代碼。

參考資料

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章