背景
CNN
是深度學習的重中之重,而conv1D
,conv2D
,和conv3D
又是CNN
的核心,所以理解conv
的工作原理就變得尤爲重要。在本博客中,將簡單梳理一下這三種卷積,以及在PyTorch中的應用方法。
參考
https://pytorch.org/docs/master/nn.html#conv1d
https://pytorch.org/docs/master/nn.functional.html#conv1d
文檔
本節的主要內容就是一邊看文檔,一邊用代碼驗證。在PyTorch
中,分別在torch.nn
和torch.nn.functional
兩個模塊都有conv1d
,conv2d
和conv3d
;從計算過程來說,兩者本身沒有太大區別;但是torch.nn
下的都是layer
,conv
的參數都是經過訓練得到;torch.nn.functional
下的都是函數,其參數可以人爲設置。本文在分析時,兩者的文檔一起看,但是實驗主要以torch.nn.functional
爲主,更加方便修改。
conv1d
由於conv
的參數都大同小異,但是conv1d更加方便理解(更容易可視化),所以我會話費大量時間詳細介紹此卷積方式。
torch.nn.functional.conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
input
input
是一維輸入,其形狀爲(Batch_Size,In_Channel,Length)
;Batch_Size
是訓練批量的大小;In_Channel
是輸入的通道數量;Length
是輸入的長度,因爲是一維輸入,所以其只有長度。下圖展示了一個一維輸入,其包括3個通道:
weight
weight是一維卷積核,其形狀爲(Out_Channel,In_Channel/Group,Kernel_Size)
;Out_Channel
是輸出的通道數量;In_Channel/Group的目
的是決定每一層的輸出是如何由輸入組成的,後續會詳細介紹,此處不妨設Group=1
;Kernel_Size
是一維卷積核的大小。下圖展示了一個一維卷積核和其對應的累加方式:
考慮到公式還是太難看懂,再畫一個圖展示上圖右邊的公式是怎麼計算的(並沒有畫bias):
group
groups
是卷積中一個非常特殊的參數,前邊已經提到,此處再詳細介紹;當Group=1
時,每一層輸出由所有輸入分別與卷積核卷積的累加得到,當Group=2
,每一層輸出僅由一般的輸入卷積累加得到(前提是,輸入通道數和輸出通道數都可以被Group
整除),當Group=In_Channel
時,每一層輸出由每一層輸入卷積得到,無需累加,換句話說,Group
的值會打斷輸入層之間的卷積關係。下圖左是是Group=1
,右圖是Group=3
,體會一下。
stride
stride
理解起來還是很容易;當stride=1
時,卷積核在原始輸入上以步長爲1
進行移動;當stride=2
時,卷積核就以不是爲2
進行移動;以此類推。下圖展示了不同的卷積stride
,其中紅色表示第一次卷積,紫色表示第二次卷積。
padding
padding
也是一個非常容易理解的概念,其主要用於處理卷積的邊界情況。對於torch.nn.Conv1d
而言,padding
有非常多的模式,什麼置爲0
,鏡像,複製等等;但是torch.nn.functional.Conv1d
就只有置爲0
。
dilation
dilation
,按照我的理解就是帶孔卷積,其控制輸入層上的取樣間隔,當dilation=1
時,就是前文所示的卷積。下圖展示了一個dilation=2
的情形,不難發現,這個參數能夠在不增強計算量的前提下增大感受野。
bias
bias
沒啥好說的。
conv2d&conv3d
原則上講,如果看文檔,會發現conv2d、conv3d和conv1d並沒有太大的區別,只不過在維度上有所區別。因此,我也就不分開介紹,直接放在一起。不難發現,唯一的區別在於維度的上升;因此爲weights的定義也有所不同,分別是(Out_Channels, Groups/In_channels,kH,kW)
和(Out_Channels, Groups/In_channels,kT,kH,kW)
。
不過這裏,我額外說一下,當輸入圖像有RGB三個通道時,似乎看起來conv2d
和conv3d
沒啥區別,反正都要對所有通道進行卷積。但是其實這裏有非常大的區別:
conv3d
除了可以在圖像平面上移動卷積之外,還可以在深度方向進行卷積;而conv2d並沒有這個能力。conv3d
中的深度和conv2d
中的Channel
是不對應的,conv3d
中的每一個深度上都可以對應多個Channel
(雖然圖不是這樣畫的),因此深度和Channel
是不同的概念。
代碼
下邊寫幾個測試代碼,並簡單說明一下。
conv1d
Batch_Size = 1
In_Channel = 2
Length = 7
Out_Channel = 2
Group = 1
Kernel_Size = 3
Padding = 1
Dilation = 1
one = torch.rand(Batch_Size,In_Channel,Length)
print('one',one)
# 定義了兩個Kernel
# 第一個Kernel取第一個Channel中間那個值
# 第二個Kernel將第一個Channel與第二個Channel相減
filter = torch.zeros(Out_Channel,int(In_Channel/Group),Kernel_Size)
filter[0][0][1] = 1
filter[1][0][1] = 1
filter[1][1][1] = -1
result = F.conv1d(one,filter,padding=Padding,groups=Group,dilation=Dilation)
print('result',result)
結果
one tensor([[[0.6465, 0.3762, 0.3227, 0.6881, 0.6364, 0.5725, 0.8627],
[0.9221, 0.7417, 0.3096, 0.1008, 0.8527, 0.4099, 0.4143]]])
result tensor([[[ 0.6465, 0.3762, 0.3227, 0.6881, 0.6364, 0.5725, 0.8627],
[-0.2756, -0.3655, 0.0131, 0.5873, -0.2163, 0.1626, 0.4485]]])
conv2d
Batch_Size = 1
In_Channel = 2
Height = 5
Width = 5
Out_Channel = 2
Group = 1
Kernel_Size_H = 3
Kernel_Size_W = 3
Padding = 1
Dilation = 1
# 定義了兩個Kernel
# 第一個Kernel取第一個Channel左上角的值
# 第二個Kernel取第二個Channel右下角的值
two = torch.rand(Batch_Size,In_Channel,Height,Width)
print('two',two)
filter = torch.zeros(Out_Channel,int(In_Channel/Group),Kernel_Size_H,Kernel_Size_W)
filter[0][0][0][0]=1
filter[1][1][2][2]=1
depth = F.conv2d(two,filter,padding=Padding)
print(depth.shape)
print(depth)
結果
two tensor([[[[0.6886, 0.5815, 0.2635, 0.5373, 0.2606],
[0.7335, 0.2440, 0.5123, 0.9990, 0.1864],
[0.5270, 0.1498, 0.0728, 0.1900, 0.0408],
[0.0819, 0.2725, 0.7476, 0.8551, 0.2504],
[0.2355, 0.5189, 0.7329, 0.8619, 0.3117]],
[[0.5712, 0.4581, 0.7050, 0.2502, 0.3364],
[0.1892, 0.6736, 0.3675, 0.2895, 0.8894],
[0.5782, 0.0020, 0.5400, 0.4404, 0.3508],
[0.3597, 0.1373, 0.0068, 0.0440, 0.9917],
[0.3296, 0.0371, 0.0367, 0.0597, 0.8797]]]])
torch.Size([1, 2, 5, 5])
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.6886, 0.5815, 0.2635, 0.5373],
[0.0000, 0.7335, 0.2440, 0.5123, 0.9990],
[0.0000, 0.5270, 0.1498, 0.0728, 0.1900],
[0.0000, 0.0819, 0.2725, 0.7476, 0.8551]],
[[0.6736, 0.3675, 0.2895, 0.8894, 0.0000],
[0.0020, 0.5400, 0.4404, 0.3508, 0.0000],
[0.1373, 0.0068, 0.0440, 0.9917, 0.0000],
[0.0371, 0.0367, 0.0597, 0.8797, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
conv3d
Batch_Size = 1
In_Channel = 2
Height = 5
Width = 5
Depth = 5
Out_Channel = 2
Group = 1
Kernel_Size_D = 3
Kernel_Size_H = 3
Kernel_Size_W = 3
Padding = 1
Dilation = 1
# 定義了兩個Kernel
# 第一個Kernel取第一個深度,左上角的值
# 第二個Kernel啥也不做
thr = torch.rand(Batch_Size,In_Channel,Depth,Height,Width)
print(thr)
filter = torch.zeros(Out_Channel,int(In_Channel/Group),Kernel_Size_D,Kernel_Size_H,Kernel_Size_W)
filter[0][0][0][0][0]=1
result = torch.conv3d(thr,filter,padding=1)
print(result)
結果
tensor([[[[[0.9226, 0.8931, 0.7071, 0.7718, 0.5866],
[0.1164, 0.8881, 0.5236, 0.7025, 0.1280],
[0.1002, 0.0013, 0.1704, 0.1424, 0.5018],
[0.8796, 0.3582, 0.2792, 0.7098, 0.9759],
[0.4871, 0.3776, 0.9242, 0.5693, 0.0594]],
[[0.7816, 0.8589, 0.4025, 0.0712, 0.4381],
[0.2501, 0.1536, 0.5014, 0.4333, 0.9369],
[0.9491, 0.8624, 0.4953, 0.6443, 0.4056],
[0.7834, 0.2791, 0.5448, 0.0204, 0.4199],
[0.1179, 0.0021, 0.3744, 0.6835, 0.4836]],
[[0.9522, 0.0417, 0.0653, 0.4445, 0.2879],
[0.2581, 0.8633, 0.2610, 0.9866, 0.9338],
[0.2689, 0.6511, 0.0543, 0.7373, 0.2599],
[0.7211, 0.9832, 0.9786, 0.3957, 0.2649],
[0.3640, 0.5514, 0.6898, 0.9033, 0.2067]],
[[0.5609, 0.7697, 0.0895, 0.1205, 0.2559],
[0.7284, 0.0997, 0.3773, 0.1338, 0.9526],
[0.1489, 0.0499, 0.6159, 0.9188, 0.9630],
[0.0550, 0.0325, 0.0619, 0.2393, 0.9781],
[0.6343, 0.4791, 0.6076, 0.7346, 0.1744]],
[[0.4132, 0.2946, 0.3903, 0.6658, 0.6961],
[0.7019, 0.1594, 0.6541, 0.5868, 0.0685],
[0.7312, 0.9089, 0.8287, 0.4644, 0.3078],
[0.7363, 0.2700, 0.7368, 0.8905, 0.2089],
[0.3708, 0.5744, 0.2688, 0.7639, 0.8681]]],
[[[0.7363, 0.4299, 0.6298, 0.6484, 0.5674],
[0.9055, 0.7832, 0.7443, 0.1624, 0.6099],
[0.8624, 0.1860, 0.2237, 0.3271, 0.5107],
[0.2373, 0.6254, 0.8148, 0.3317, 0.6703],
[0.8364, 0.2029, 0.2762, 0.4807, 0.6596]],
[[0.1022, 0.9687, 0.4097, 0.9130, 0.5343],
[0.3665, 0.0765, 0.0136, 0.6457, 0.5640],
[0.3436, 0.1625, 0.8261, 0.5664, 0.7331],
[0.4402, 0.8114, 0.4218, 0.5149, 0.3197],
[0.2731, 0.3032, 0.9294, 0.9505, 0.3776]],
[[0.2852, 0.0566, 0.5607, 0.0690, 0.6652],
[0.5315, 0.5046, 0.9546, 0.5480, 0.4868],
[0.5333, 0.7227, 0.0407, 0.6066, 0.6386],
[0.5846, 0.2641, 0.0451, 0.0521, 0.8822],
[0.8929, 0.2496, 0.5646, 0.3253, 0.8867]],
[[0.3010, 0.5833, 0.6355, 0.2783, 0.4770],
[0.6493, 0.2489, 0.9739, 0.8326, 0.7717],
[0.3469, 0.9503, 0.3222, 0.4197, 0.5231],
[0.2533, 0.4396, 0.8671, 0.6622, 0.3155],
[0.0444, 0.3937, 0.0983, 0.5874, 0.6237]],
[[0.8788, 0.4389, 0.2793, 0.9504, 0.5325],
[0.4858, 0.3797, 0.3282, 0.6697, 0.5938],
[0.8738, 0.4183, 0.1169, 0.2855, 0.2764],
[0.0590, 0.4542, 0.8047, 0.1575, 0.3735],
[0.2168, 0.4904, 0.1830, 0.2141, 0.4013]]]]])
tensor([[[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.9226, 0.8931, 0.7071, 0.7718],
[0.0000, 0.1164, 0.8881, 0.5236, 0.7025],
[0.0000, 0.1002, 0.0013, 0.1704, 0.1424],
[0.0000, 0.8796, 0.3582, 0.2792, 0.7098]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.7816, 0.8589, 0.4025, 0.0712],
[0.0000, 0.2501, 0.1536, 0.5014, 0.4333],
[0.0000, 0.9491, 0.8624, 0.4953, 0.6443],
[0.0000, 0.7834, 0.2791, 0.5448, 0.0204]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.9522, 0.0417, 0.0653, 0.4445],
[0.0000, 0.2581, 0.8633, 0.2610, 0.9866],
[0.0000, 0.2689, 0.6511, 0.0543, 0.7373],
[0.0000, 0.7211, 0.9832, 0.9786, 0.3957]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5609, 0.7697, 0.0895, 0.1205],
[0.0000, 0.7284, 0.0997, 0.3773, 0.1338],
[0.0000, 0.1489, 0.0499, 0.6159, 0.9188],
[0.0000, 0.0550, 0.0325, 0.0619, 0.2393]]],
[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]])
總結
莫名其妙就寫了一大堆,也許還是不懂,但是跑跑代碼就明白了。