pytorch中的torch.nn.Unfold和torch.nn.Fold

1. torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

torch.nn.Unfold按照官方的說法,既從一個batch的樣本中,提取出滑動的局部區域塊,也就是卷積操作中的提取kernel filter對應的滑動窗口。
1)由上可知,torch.nn.Unfold的參數跟nn.Conv2d的參數很相似,即,kernel_size(卷積核的尺寸),dilation(空洞大小),padding(填充大小)和stride(步長)。
2)官方解釋中:unfold的輸入爲( N, C, H, W),其中N爲batch_size,C是channel個數,H和W分別是channel的長寬。則unfold的輸出爲( N, C × π ( k e r n e l _ s i z e ) C\times\pi (kernel\_size) C×π(kernel_size), L),其中 π ( k e r n e l _ s i z e ) \pi (kernel \_size) π(kernel_size)爲kernel_size長和寬的乘積, L是channel的長寬根據kernel_size的長寬滑動裁剪後,得到的區塊的數量。
3)例如:輸入(1, 2, 4, 4),假設kernel_size = (2, 2),stride = 2,根據官方給出的L計算公式
在這裏插入圖片描述
其中d是channel的維度,二維圖像既長寬的維度。則得到L(區塊數量)爲 :
在這裏插入圖片描述每個區塊的大小爲 C × \times × kernel_size[ 0 ] × \times × kernel_size[ 1 ] ,既 2 × 2 × 2 = 8,作爲輸出的第二個維度。
4) 代碼展示:






inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())
print(patches)

在這裏插入圖片描述
5)對代碼結果分析,nn.Unfold對輸入channel的每一個kernel_size[ 0 ] × \times × kernel_size[ 1 ] 的滑動窗口區塊做了展平操作。
在這裏插入圖片描述

2. torch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)

torch.nn.Fold的操作與Unfold相反,將提取出的滑動局部區域塊還原成batch的張量形式。
1)代碼如下:

fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patches)
print(inputs_restore)
print(inputs_restore.size())

2)代碼分析:Fold的操作通過設定output_size=(4, 4),完成與Unfold的互逆的操作。
在這裏插入圖片描述
借鑑博文:https://blog.csdn.net/weixin_44076434/article/details/106545037?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromBaidu-1.control&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromBaidu-1.control

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章