如何對batch的數據求Gram矩陣

Gram矩陣概念和理解

在風格遷移中,我們要比較生成圖片和風格圖片的相似性,評判標準就是通過計算Gram矩陣得到的。關於Gram矩陣的定義,可以參考[1]。

由這個矩陣的樣子,很容易就想到協方差矩陣。如果協方差矩陣是什麼忘了的化可以參考[2],可以看到Gram矩陣是沒有減去均值的協方差矩陣。協方差矩陣是一種相關性度量的矩陣,通過協方差來度量相關性,也就是度量兩個圖片風格的相似性。(如果相對協方差和相關係數有進一步瞭解,可以參考[3])

如何通過代碼實現Gram矩陣計算

瞭解Gram矩陣的概念和性質 ,我們就來看一看如何用代碼來實現Gram矩陣的計算。這裏,使用PyTorch來實現計算過程。

PyTorch中有兩個函數torch.mmtorch.bmm前者是計算矩陣乘法,後者是計算batch數據的矩陣乘法,風格遷移中是對batch數據進行操作,所以使用bmm。

我們創造一個batch爲2,單通道,2*2大小的數據

a = torch.arange(8, dtype=torch.int).reshape(2, 1, 2, 2)
a
>>> tensor([[[[0, 1],
          [2, 3]]],


        [[[4, 5],
          [6, 7]]]], dtype=torch.int32)

之後從新reshape一下,將w和h通道的數據合起來,變成向量形式

features = a.view(2, 1, 4)
features
>>>	tensor([[[0, 1, 2, 3]],

        [[4, 5, 6, 7]]], dtype=torch.int32)

爲了構造計算Gram矩陣的向量,對shape進行一個交換操作

features_t = features.transpose(1, 2)
features_t
>>>	tensor([[[0],
         [1],
         [2],
         [3]],

        [[4],
         [5],
         [6],
         [7]]], dtype=torch.int32)

之後用矩陣乘法把這兩個向量乘起來就可以了,就計算出Gram矩陣了。

gram = features.bmm(features_t)
gram
>>>	tensor([[[ 14]],

        [[126]]], dtype=torch.int32)

Reference

[1]Gram格拉姆矩陣在風格遷移中的應用
[2]如何直觀地理解「協方差矩陣」
[3]如何通俗易懂地解釋「協方差」與「相關係數」的概念?

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章