Pytorch-nn.Conv2d中groups參數的理解

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

group參數的作用是控制分組卷積。

一句話,輸入in_channels被劃分爲groups組, 每組通道數爲in_channels/groups。每組需要重複計算out_channels/(n_channels/groups)次。

https://blog.csdn.net/ECNU_LZJ/article/details/105265843?utm_medium=distribute.pc_relevant.none-task-blog-baidujs-1

中的代碼:

import torch
import torch.nn as nn

x = torch.Tensor([1, 10, 100, 1000, 10000, 100000]).view(1, -1, 1, 1)
print("x:", x.int())
conv = nn.Conv2d(
    in_channels=6, out_channels=9, kernel_size=1, stride=1, padding=0, groups=3, bias=False
)
print("Conv weight size:", conv.weight.data.size())
conv.weight.data = torch.arange(1, 19).float().view(9, 2, 1, 1)
print("Conv weight data:", conv.weight.data.int())
output = conv(x).int()
print("Output:", output)

如果是正常的卷積,參數大小應該爲: [9(輸出通道), 6(輸入通道), 1(核h), 1(核w)]。
這是因爲輸出是9個通道,每個通道都需要一個[6, 1, 1]大小的卷積(輸入的每個通道都參與到了運算)。
但是我們可以從代碼的運行結果中看到Conv層的參數大小爲: [9, 2, 1, 1]。這就說明對於每個輸出的通道,只有兩個輸入的通道參與了運算。
事實就是這樣,分組卷積的過程中只有部分輸入的通道才參與了運算。我們就以上面的代碼爲例進行講解。

首先將輸入的6個通道分爲3組: [1, 10], [100, 1000], [10000, 100000],每一組都用來生成輸出的一個通道。
3個組只能生成3個輸出通道,但是要求輸出是9個通道,所以每個組需要重複計算三次。
輸出的第1個通道: 1 * 1 + 2 * 10 = 21,需要用到輸入的第1組。
輸出的第2個通道: 3 * 1 + 4 * 10 = 43,需要用到輸入的第1組。
輸出的第3個通道: 5 * 1 + 6 * 10 = 65,需要用到輸入的第1組。
輸出的第4個通道: 7 * 100 + 8 * 1000 = 8700,需要用到輸入的第2組。
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章