Pytoch 中 torch.nn 與 torch.nn.functional 的區別

在初學Pytorch 創建模型的時候,總會出現不知道要把layer放在 init() 中還是 forwad() 中,也不知道到底該使用nn.Conv2d 還是F.conv2d. 爲此帶來了不必要的煩惱.我爲了搞清用法查看了官方doc並在pytorch論壇上做了詢問,此爲討論的鏈接
整理結果如下:


torch.nn

torch.nn 這個模塊下面存的主要是 Module類.以torch.nn.Conv2d爲例, 也就是說 torch.nn.Conv2d這種"函數"其實是個 Module類,在實例化類後會初始化2d卷積所需要的參數. 這些參數會在你做forward和 backward之後根據loss進行更新,所以通常存放在定義模型的 _init_() 中.如:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        #其實這裏就是類的實例化,需要定義初始參數
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.act = nn.ReLU()
        
    def forward(self, x):
        x = self.act(self.conv1(x))
        return x

那在定義模型時,可不可以把nn.Conv2d寫在forward處?

  • 不可以
    如果寫成類似這樣會有什麼影響呢?
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.act = nn.ReLU()
        
    def forward(self, x):
        # 把卷積函數寫在forward中
        x= nn.Conv2d(3, 6, 3, 1, 1)(x)
        x = self.act(x)
        return x

把nn.Conv2d寫在forward中就相當於模型每次跑forward的時候,都重新實例化了nn.Conv2d和nn.Conv2d的參數,導致模型學不到參數.

torch.nn.functional

torch.nn.functional.x 爲函數,與torch.nn不同, torch.nn.x中包含了初始化需要的參數等 attributes 而torch.nn.functional.x則需要把相應的weights 作爲輸入參數傳遞,才能完成運算, 所以用torch.nn.functional創建模型時需要創建並初始化相應參數.
例如:

import torch.nn.functional as F
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.act = nn.ReLU()
        self.weighs = nn.Parameter(torch.rand(x,x,x,x))
        self.bias = nn.Parameter(torch.rand(x))
        
    def forward(self, x):
        # 把卷積函數寫在forward中,把w和b傳入函數
        x= F.conv2d(x,self.weighs,self.bias)
        x = self.act(x)
        return x

查看兩者的doc即可看出區別:

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor

CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’)

即一個側重數據結構,一個側重算法運算. 其實兩個都是完成了同樣的功能,只是實現方式有些不同而已

總結

torch.nn.X torch.nn.functional.X
是 類 是函數
結構中包含所需要初始化的參數 需要在函數外定義並初始化相應參數,並作爲參數傳入
一般情況下放在_init_ 中實例化,並在forward中完成操作 一般在_init_ 中初始化相應參數,在forward中傳入

所以 模型要麼寫成這樣

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # which has its own hidden parameters 
        self.conv_like = nn.convlike() 
        
    def forward(self, x):
        x = self.conv_like(x)

要麼寫成這樣:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # which will be used in nn.functional.funs 
        self.func_params = params 
        
    def forward(self, x):
        x = nn.functional.funs(x,self.func_params)
當所需的函數不含有需要學習的參數且在train和test階段運行方法一致時,無論用torch.nn.X 還是torch.nn.functional.X 都可以,並且既可以現在 init中聲明,也可以直接在forward中使用
  • 如: torch.nn.functional.relu, torch.nn.ReLU
train和test階段運行方法一致時,儘量用 torch.nn,避免了手動控制的麻煩
  • batch normalization 和 dropout
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章