使用符号:
输入尺寸(input): i
卷积核大小(kernel size): k
步幅(stride): s
边界扩充(padding): p
输出尺寸(output): o
卷积公式
没有padding,且s=1
公式1:对于任意的i和k,如果s=1,p=0,则
o=(i−k)+1
有padding,且s=1
公式2:对于任意的i和k,p,如果s=1,则
o=(i−k)+2p+1
Half (same) padding
公式3:如果我们希望输出的大小等于输入的大小,那么首先保证,k是奇数,于是对于任意的i,对于奇数的k=2n+1,s=1,p=⌊k/2⌋=n,于是
o=i+2⌊k/2⌋−(k−1)=i+2n−2n=i
这也是为什么,参数组合k=3,s=1,p=1以及k=5,s=1,p=2那么常见的原因,他们不会改变output的大小。
Full padding
这是一种让input size增加的padding设置方式:
公式4:对于任意的i,k,并设p=k−1,s=1则
o=i+2(k−1)−(k−1)=i+(k−1)
没有padding,s>1
上面都是讨论s=1的情况,现在讨论s>1的情况,首先没有padding的话公式是:
公式5:对于任意的i,k,s,若p=0,则
o=⌊si−k⌋+1
有padding,s>1
现在如果有padding:
公式6:对于任意的i,k,s,p
o=⌊si+2p−k⌋+1
可以看到,s的增加会使得i成倍地减少,如果想要让o=i/2,一个最常用的配置是设s=2,然后−2≤2p−k≤−1,也就是s=2,p=1,k=4,或者s=2,p=1,k=3都可以
Pooling 公式
pooling其实只是一种特别的卷积核,所以他的计算公式跟卷积是一模一样的,而且由于pooling是没有padding的,所以他的计算公式就是:
公式7:对于任意的i,k,s
o=⌊si−k⌋+1
反卷积公式
没有padding,且s=1
公式8如果正向卷积对于任意的k,且s=1,p=0,那么如果其反卷积的设置为k′=k,s′=s,p′=k−1,则反卷积的输出大小为:
o′=i′+(k−1)
显然,这个跟公式1是一一对应的,(根据公式1,可以推出i=o−1+k)
有padding,且s=1
公式9:如果正向卷积对于任意的k,p,且s=1,那么如果其反卷积的设置为k′=k,s′=s,p′=k−p−1,则反卷积的输出大小为:
o′=i′+(k−1)−2p
类似的,这个公式是跟公式2一一对应的。显然,当k=3,s=1,p=1时,其反卷积的参数恰好也是k′=3,s′=1,p′=3−1−1=1,是一模一样的,另外一个常用的配置是,k=5,s=1,p=3,此时,反卷积的参数也是跟正向卷积一样的。
Half (same) padding
公式10: 如果正向卷积对于任意的k=2n+1为奇数,且s=1,p=⌊k/2⌋=n,那么如果其反卷积的设置为k′=k,s′=s,p′=p,则反卷积的输出大小为:
o′=i′+(k−1)−2p=i′+2n−2n=i′
这个公式是跟公式3是一一对应的,反卷积同意也能得到相同大小的output与input。
Full padding
这是对应于公式4的反卷积,将input减少的
公式11: 如果正向卷积对于任意的k,且s=1,p=k−1,那么如果其反卷积的设置为k′=k,s′=s,p′=0,则反卷积的输出大小为:
o′=i′+(k−1)−2p=i′−(k−1)
没有padding,且s>1
在反卷积中,如果s>1,那么它在像素间插入空白的间隔,如下图所示:
经过扩大的图大小变成了
i′^=i+(s−1)(i−1)
每一块输入之间都插入了(s-1)个空白点。经过插入后的大小记为i′^
公式12: 如果正向卷积对于任意的k,s,且p=0,以及i−k是s的整数倍,那么如果其反卷积,将想原始图像的输入拓展成i′^=i+(s−1)(i−1),然后设置为k′=k,s′=1,p′=k−1,则反卷积的输出大小为:
o′=s(i′−1)+k
有padding,且s>1
公式13: 如果正向卷积对于任意的k,s,p,以及i+2p−k是s的整数倍,那么如果其反卷积,将想原始图像的输入拓展成i′^=i+(s−1)(i−1),然后设置为k′=k,s′=1,p′=k−p−1,则反卷积的输出大小为:
o′=s(i′−1)+k−2p
在上面为了方便计算,都是假设是整数倍,如果没有这个假设,那么:
公式14: 如果正向卷积对于任意的k,s,p,,那么如果其反卷积,将想原始图像的输入拓展成i′^=i+(s−1)(i−1),然后设置为k′=k,s′=1,p′=k−1,则反卷积的输出大小为:
o′=s(i′−1)+a+k−2p
其中a=(i+2p−k)mods.
暴力测试参数数量
说实话。。最方便的方法还是直接写代码测测维度大小:
import torch
import torch.nn as nn
def paras_cnn(k,s,p,i=64):
x=torch.ones(1,1,i,i)
conv = torch.nn.Conv2d(1, 1, kernel_size=k, stride=s, padding=p)
convt= torch.nn.ConvTranspose2d(1, 1, kernel_size=k, stride=s, padding=p)
h1=conv(x)
h2=convt(x)
y=convt(h1)
print("conv(x):{} \t convT(x):{} \t convT(conv(x)):{}".format((h1.shape[2],h1.shape[3]),(h2.shape[2],h2.shape[3]),(y.shape[2],y.shape[3])))
return h1.shape[2],h1.shape[3],h2.shape[2],h2.shape[3],y.shape[2],y.shape[3]
参考资料
https://arxiv.org/abs/1603.07285
https://zhuanlan.zhihu.com/p/57348649