Implementing im2col in PyTorch

In PyTorch, the im2col operation can be implemented by combining Tensor.unfold, torch.cat, and torch.transpose.

TAKEAWAY:

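import torch

stride = (1, 1)
kernel_size = (3, 3)

x = torch.arange(0, 25).reshape(5, 5)

# unfold rows, unfold columns, then transpose + cat twice to flatten each window into one row.
# tuple(...) is needed because recent PyTorch versions require torch.cat to receive a sequence of tensors.
y = torch.cat(tuple(torch.cat(tuple(x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1]).transpose(0, 2)), dim=2).transpose(0, 1)), dim=0)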
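
As a sanity check, the one-liner can be wrapped in a helper and compared with torch.nn.functional.unfold, PyTorch's built-in im2col for batched 4-D input. This is only a sketch: the helper name im2col_2d is made up for illustration, and it assumes a recent PyTorch release where torch.cat takes a tuple of tensors.

```python
import torch
import torch.nn.functional as F


def im2col_2d(x, kernel_size, stride):
    """Hypothetical helper: single-channel im2col via unfold/transpose/cat."""
    # windows: (out_h, out_w, k_h, k_w)
    windows = x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1])
    # (out_w, out_h, k_h * k_w)
    step1 = torch.cat(tuple(windows.transpose(0, 2)), dim=2)
    # (out_h * out_w, k_h * k_w)
    return torch.cat(tuple(step1.transpose(0, 1)), dim=0)


x = torch.arange(25, dtype=torch.float32).reshape(5, 5)
cols = im2col_2d(x, kernel_size=(3, 3), stride=(1, 1))

# F.unfold expects (N, C, H, W) and returns (N, C * k_h * k_w, L);
# transposing gives one row per window, like the result above.
reference = F.unfold(x[None, None], kernel_size=(3, 3), stride=(1, 1))[0].T

print(cols.shape)  # torch.Size([9, 9])
print(torch.equal(cols, reference))
```

Note that F.unfold puts the kernel elements on the rows and the window positions on the columns, which is why the reference is transposed before the comparison.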

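The single-channel im2col operation is explained step by step below, using a small matrix as an example. (The printed outputs are in the format of an older PyTorch release; recent versions print tensor([...]) instead.)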

x = torch.arange(0, 25).reshape(5, 5)
print(x)

  0   1   2   3   4
  5   6   7   8   9
 10  11  12  13  14
 15  16  17  18  19
 20  21  22  23  24
[torch.FloatTensor of size 5x5]

Define the kernel size and the stride:

kernel_size = (3, 3)
stride = (1, 1)

First, use unfold to slice the matrix into small windows, starting along the rows (dimension 0):

x = x.unfold(0, kernel_size[0], stride[0])
print(x)

(0 ,.,.) = 
   0   5  10
   1   6  11
   2   7  12
   3   8  13
   4   9  14

(1 ,.,.) = 
   5  10  15
   6  11  16
   7  12  17
   8  13  18
   9  14  19

(2 ,.,.) = 
  10  15  20
  11  16  21
  12  17  22
  13  18  23
  14  19  24
[torch.FloatTensor of size 3x5x3]
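
To make the indexing concrete: Tensor.unfold(dimension, size, step) appends the window as a new last dimension, so after unfolding along dimension 0 the element at [i, j, k] is the original x[i * step + k, j]. A minimal sketch that checks this (the variable names rows, size and step are only illustrative):

```python
import torch

x = torch.arange(25).reshape(5, 5)
size, step = 3, 1

# (window index, column, offset inside the window)
rows = x.unfold(0, size, step)
print(rows.shape)  # torch.Size([3, 5, 3])

# Compare every element with direct indexing into x
ok = all(
    bool(rows[i, j, k] == x[i * step + k, j])
    for i in range(rows.shape[0])
    for j in range(rows.shape[1])
    for k in range(rows.shape[2])
)
print(ok)
```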

Then slice along the columns (dimension 1):

x = x.unfold(1, kernel_size[1], stride[1])
print(x)

(0 ,0 ,.,.) = 
   0   1   2
   5   6   7
  10  11  12

(0 ,1 ,.,.) = 
   1   2   3
   6   7   8
  11  12  13

(0 ,2 ,.,.) = 
   2   3   4
   7   8   9
  12  13  14

(1 ,0 ,.,.) = 
   5   6   7
  10  11  12
  15  16  17

(1 ,1 ,.,.) = 
   6   7   8
  11  12  13
  16  17  18

(1 ,2 ,.,.) = 
   7   8   9
  12  13  14
  17  18  19

(2 ,0 ,.,.) = 
  10  11  12
  15  16  17
  20  21  22

(2 ,1 ,.,.) = 
  11  12  13
  16  17  18
  21  22  23

(2 ,2 ,.,.) = 
  12  13  14
  17  18  19
  22  23  24
[torch.FloatTensor of size 3x3x3x3]
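
After both unfolds, the entry at [i, j] is the kernel-sized patch whose top-left corner sits at row i * stride[0], column j * stride[1]. The sketch below compares each patch with plain slicing; the stride of 2 is an assumption chosen only to show that the step is respected (the walkthrough itself uses 1).

```python
import torch

x = torch.arange(25).reshape(5, 5)
kernel_size = (3, 3)
stride = (2, 2)  # assumed for illustration; the walkthrough uses (1, 1)

windows = x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1])
print(windows.shape)  # torch.Size([2, 2, 3, 3])

# Each windows[i, j] should equal the corresponding slice of x
for i in range(windows.shape[0]):
    for j in range(windows.shape[1]):
        top, left = i * stride[0], j * stride[1]
        patch = x[top:top + kernel_size[0], left:left + kernel_size[1]]
        assert torch.equal(windows[i, j], patch)
```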

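Note that the next step uses torch.cat for concatenation. Because cat joins the sub-tensors taken along dimension 0, the dimension order first has to be adjusted with transpose. In this example every dimension has size 3, so the effect is not visible in the printed output; try an input whose dimensions differ to see what changes: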

x = x.transpose(0, 2)
print(x)

(0 ,0 ,.,.) = 
   0   1   2
   5   6   7
  10  11  12

(0 ,1 ,.,.) = 
   1   2   3
   6   7   8
  11  12  13

(0 ,2 ,.,.) = 
   2   3   4
   7   8   9
  12  13  14

(1 ,0 ,.,.) = 
   5   6   7
  10  11  12
  15  16  17

(1 ,1 ,.,.) = 
   6   7   8
  11  12  13
  16  17  18

(1 ,2 ,.,.) = 
   7   8   9
  12  13  14
  17  18  19

(2 ,0 ,.,.) = 
  10  11  12
  15  16  17
  20  21  22

(2 ,1 ,.,.) = 
  11  12  13
  16  17  18
  21  22  23

(2 ,2 ,.,.) = 
  12  13  14
  17  18  19
  22  23  24
[torch.FloatTensor of size 3x3x3x3]
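
Since the 5x5 input with a 3x3 kernel hides the effect of the transpose, here is a small sketch with a non-square input and kernel. The 4x6 input and the (2, 3) kernel are assumptions picked so that every axis has a different size:

```python
import torch

x = torch.arange(24).reshape(4, 6)
kernel_size = (2, 3)
stride = (1, 1)

windows = x.unfold(0, kernel_size[0], stride[0]).unfold(1, kernel_size[1], stride[1])
print(windows.shape)  # torch.Size([3, 4, 2, 3]) = (out_h, out_w, k_h, k_w)

swapped = windows.transpose(0, 2)
print(swapped.shape)  # torch.Size([2, 4, 3, 3]) = (k_h, out_w, out_h, k_w)

# Iterating over dim 0 now yields one slice per kernel row; joining the slices
# on the last axis lays the kernel rows side by side.
joined = torch.cat(tuple(swapped), dim=2)
print(joined.shape)   # torch.Size([4, 3, 6]) = (out_w, out_h, k_h * k_w)
```

Without the transpose, iterating over dimension 0 would walk over output rows instead of kernel rows, and the concatenation would mix different windows rather than flatten a single one.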

Then concatenate with cat:

x = torch.cat(tuple(x), dim=2)  # recent PyTorch versions require a sequence of tensors here
print(x)

(0 ,.,.) = 
   0   1   2   5   6   7  10  11  12
   5   6   7  10  11  12  15  16  17
  10  11  12  15  16  17  20  21  22

(1 ,.,.) = 
   1   2   3   6   7   8  11  12  13
   6   7   8  11  12  13  16  17  18
  11  12  13  16  17  18  21  22  23

(2 ,.,.) = 
   2   3   4   7   8   9  12  13  14
   7   8   9  12  13  14  17  18  19
  12  13  14  17  18  19  22  23  24
[torch.FloatTensor of size 3x3x9]
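
For readers who find the transpose-then-cat combination hard to follow, the same intermediate tensor can, as far as I can tell, also be produced with permute and reshape. The following is only a sketch to compare the two routes on the example matrix, not part of the original recipe:

```python
import torch

x = torch.arange(25).reshape(5, 5)
windows = x.unfold(0, 3, 1).unfold(1, 3, 1)  # (out_h, out_w, k_h, k_w)

# Route used in this post: move k_h to the front, then join on the last axis
via_cat = torch.cat(tuple(windows.transpose(0, 2)), dim=2)

# Alternative: swap the two output axes, then merge (k_h, k_w) into one axis
via_reshape = windows.permute(1, 0, 2, 3).reshape(3, 3, 9)

print(torch.equal(via_cat, via_reshape))
```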

Next, transpose once more:

x = x.transpose(0, 1)
print(x)

(0 ,.,.) = 
   0   1   2   5   6   7  10  11  12
   1   2   3   6   7   8  11  12  13
   2   3   4   7   8   9  12  13  14

(1 ,.,.) = 
   5   6   7  10  11  12  15  16  17
   6   7   8  11  12  13  16  17  18
   7   8   9  12  13  14  17  18  19

(2 ,.,.) = 
  10  11  12  15  16  17  20  21  22
  11  12  13  16  17  18  21  22  23
  12  13  14  17  18  19  22  23  24
[torch.FloatTensor of size 3x3x9]

One final cat, this time along dimension 0, and we are done:

x = torch.cat(tuple(x), dim=0)  # each slice is 2-D (3x9), so the slices are stacked along dim 0
print(x)

    0     1     2     5     6     7    10    11    12
    1     2     3     6     7     8    11    12    13
    2     3     4     7     8     9    12    13    14
    5     6     7    10    11    12    15    16    17
    6     7     8    11    12    13    16    17    18
    7     8     9    12    13    14    17    18    19
   10    11    12    15    16    17    20    21    22
   11    12    13    16    17    18    21    22    23
   12    13    14    17    18    19    22    23    24
[torch.FloatTensor of size 9x9]
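
The point of im2col is that convolution then reduces to a matrix product. The sketch below is a hedged illustration of that idea: the filter values are arbitrary, and the im2col matrix is built with a plain reshape of the unfolded windows, which should give the same 9x9 matrix as the cat-based steps above. The result is compared with torch.nn.functional.conv2d.

```python
import torch
import torch.nn.functional as F

x = torch.arange(25, dtype=torch.float32).reshape(5, 5)
kernel = torch.tensor([[1., 0., -1.],
                       [2., 0., -2.],
                       [1., 0., -1.]])  # arbitrary example filter

# One row per output position, one column per kernel element
cols = x.unfold(0, 3, 1).unfold(1, 3, 1).reshape(-1, 9)  # (9, 9)

# Matrix-vector product, then fold the 9 results back into a 3x3 output map
out = (cols @ kernel.reshape(-1)).reshape(3, 3)

# conv2d (which computes cross-correlation) on the same data, as a reference
expected = F.conv2d(x[None, None], kernel[None, None])[0, 0]
print(torch.allclose(out, expected))
```

With several filters, the flattened kernels would be stacked as columns of a matrix, so that a single matrix multiplication produces all output channels at once.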