使用符號:
輸入尺寸(input): i
卷積核大小(kernel size): k
步幅(stride): s
邊界擴充(padding): p
輸出尺寸(output): o
卷積公式
沒有padding,且s=1
公式1:對於任意的i和k,如果s=1,p=0,則
o=(i−k)+1
有padding,且s=1
公式2:對於任意的i和k,p,如果s=1,則
o=(i−k)+2p+1
Half (same) padding
公式3:如果我們希望輸出的大小等於輸入的大小,那麼首先保證,k是奇數,於是對於任意的i,對於奇數的k=2n+1,s=1,p=⌊k/2⌋=n,於是
o=i+2⌊k/2⌋−(k−1)=i+2n−2n=i
這也是爲什麼,參數組合k=3,s=1,p=1以及k=5,s=1,p=2那麼常見的原因,他們不會改變output的大小。
Full padding
這是一種讓input size增加的padding設置方式:
公式4:對於任意的i,k,並設p=k−1,s=1則
o=i+2(k−1)−(k−1)=i+(k−1)
沒有padding,s>1
上面都是討論s=1的情況,現在討論s>1的情況,首先沒有padding的話公式是:
公式5:對於任意的i,k,s,若p=0,則
o=⌊si−k⌋+1
有padding,s>1
現在如果有padding:
公式6:對於任意的i,k,s,p
o=⌊si+2p−k⌋+1
可以看到,s的增加會使得i成倍地減少,如果想要讓o=i/2,一個最常用的配置是設s=2,然後−2≤2p−k≤−1,也就是s=2,p=1,k=4,或者s=2,p=1,k=3都可以
Pooling 公式
pooling其實只是一種特別的卷積核,所以他的計算公式跟卷積是一模一樣的,而且由於pooling是沒有padding的,所以他的計算公式就是:
公式7:對於任意的i,k,s
o=⌊si−k⌋+1
反捲積公式
沒有padding,且s=1
公式8如果正向卷積對於任意的k,且s=1,p=0,那麼如果其反捲積的設置爲k′=k,s′=s,p′=k−1,則反捲積的輸出大小爲:
o′=i′+(k−1)
顯然,這個跟公式1是一一對應的,(根據公式1,可以推出i=o−1+k)
有padding,且s=1
公式9:如果正向卷積對於任意的k,p,且s=1,那麼如果其反捲積的設置爲k′=k,s′=s,p′=k−p−1,則反捲積的輸出大小爲:
o′=i′+(k−1)−2p
類似的,這個公式是跟公式2一一對應的。顯然,當k=3,s=1,p=1時,其反捲積的參數恰好也是k′=3,s′=1,p′=3−1−1=1,是一模一樣的,另外一個常用的配置是,k=5,s=1,p=3,此時,反捲積的參數也是跟正向卷積一樣的。
Half (same) padding
公式10: 如果正向卷積對於任意的k=2n+1爲奇數,且s=1,p=⌊k/2⌋=n,那麼如果其反捲積的設置爲k′=k,s′=s,p′=p,則反捲積的輸出大小爲:
o′=i′+(k−1)−2p=i′+2n−2n=i′
這個公式是跟公式3是一一對應的,反捲積同意也能得到相同大小的output與input。
Full padding
這是對應於公式4的反捲積,將input減少的
公式11: 如果正向卷積對於任意的k,且s=1,p=k−1,那麼如果其反捲積的設置爲k′=k,s′=s,p′=0,則反捲積的輸出大小爲:
o′=i′+(k−1)−2p=i′−(k−1)
沒有padding,且s>1
在反捲積中,如果s>1,那麼它在像素間插入空白的間隔,如下圖所示:
經過擴大的圖大小變成了
i′^=i+(s−1)(i−1)
每一塊輸入之間都插入了(s-1)個空白點。經過插入後的大小記爲i′^
公式12: 如果正向卷積對於任意的k,s,且p=0,以及i−k是s的整數倍,那麼如果其反捲積,將想原始圖像的輸入拓展成i′^=i+(s−1)(i−1),然後設置爲k′=k,s′=1,p′=k−1,則反捲積的輸出大小爲:
o′=s(i′−1)+k
有padding,且s>1
公式13: 如果正向卷積對於任意的k,s,p,以及i+2p−k是s的整數倍,那麼如果其反捲積,將想原始圖像的輸入拓展成i′^=i+(s−1)(i−1),然後設置爲k′=k,s′=1,p′=k−p−1,則反捲積的輸出大小爲:
o′=s(i′−1)+k−2p
在上面爲了方便計算,都是假設是整數倍,如果沒有這個假設,那麼:
公式14: 如果正向卷積對於任意的k,s,p,,那麼如果其反捲積,將想原始圖像的輸入拓展成i′^=i+(s−1)(i−1),然後設置爲k′=k,s′=1,p′=k−1,則反捲積的輸出大小爲:
o′=s(i′−1)+a+k−2p
其中a=(i+2p−k)mods.
暴力測試參數數量
說實話。。最方便的方法還是直接寫代碼測測維度大小:
import torch
import torch.nn as nn
def paras_cnn(k,s,p,i=64):
x=torch.ones(1,1,i,i)
conv = torch.nn.Conv2d(1, 1, kernel_size=k, stride=s, padding=p)
convt= torch.nn.ConvTranspose2d(1, 1, kernel_size=k, stride=s, padding=p)
h1=conv(x)
h2=convt(x)
y=convt(h1)
print("conv(x):{} \t convT(x):{} \t convT(conv(x)):{}".format((h1.shape[2],h1.shape[3]),(h2.shape[2],h2.shape[3]),(y.shape[2],y.shape[3])))
return h1.shape[2],h1.shape[3],h2.shape[2],h2.shape[3],y.shape[2],y.shape[3]
參考資料
https://arxiv.org/abs/1603.07285
https://zhuanlan.zhihu.com/p/57348649