淺談torch.nn庫和torch.nn.functional庫(Pytorch)

淺談torch.nn庫和torch.nn.functional庫

這兩個庫很類似,都涵蓋了神經網絡的各層操作,只是用法有點不同,

nn下是類實現,nn.functional下是函數實現。

conv1d

  • 在nn下是一個類,一般繼承nn.module通過定義forward()函數計算其值
class Conv1d(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        super(Conv1d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias)

    def forward(self, input):
        return torch.nn.functional.conv1d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
  • 在nn.functional下直接傳入參數即可使用,其會直接返回一個torch.nn.functional的函數,和上面類中的forward()中的函數一致
def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
           groups=1):
    if input is not None and input.dim() != 3:
        raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))

    f = ConvNd(_single(stride), _single(padding), _single(dilation), False,
               _single(0), groups, torch.backends.cudnn.benchmark,
               torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
    return f(input, weight, bias)
  • nn.Xxx不需要自己定義和管理參數weight;而nn.functional.xxx需要自己定義weight,每次調用的時候都需要手動傳入weight

同樣droptout用nn定義的話在訓練時生效,在eval()時無效。

損失函數Loss(交叉熵)

  • nn庫

import torch
import torch.nn as nn

Loss = nn.BCELoss()

a = torch.ones(2,2)
b = torch.ones(2,2)
c = Loss(a,b)

  • nn.functional庫

import torch
import torch.nn.functional as nn

a = torch.ones(2,2)
b = torch.ones(2,2)
c = nn.binary_cross_entropy(a,b)

c的結果都一樣爲0,即兩個分佈高度相似

總結一下,兩個庫都可以實現神經網絡的各層運算。其他包括卷積、池化、padding、激活(非線性層)、線性層、正則化層、其他損失函數Loss,兩者都可以實現

nn.functional.xxx是函數接口,而nn.Xxx是nn.functional.xxx的類封裝,並且nn.Xxx都繼承於一個共同祖先nn.Module。因此nn.Xxx除了具有nn.functional.xxx功能(通過類中的forward方法實現),內部附帶了nn.Module相關的屬性和方法,例如train(), eval(),load_state_dict, state_dict 等,可以自動管理各層的參數,同時還可以實現如Sequential()將多個運算層組合爲一個邏輯層。

參考

https://pytorch.org/docs/stable/nn.html#

https://pytorch.org/docs/stable/nn.functional.html#

https://www.zhihu.com/question/66782101

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章