淺談torch.nn庫和torch.nn.functional庫
這兩個庫很類似,都涵蓋了神經網絡的各層操作,只是用法有點不同,
nn下是類實現,nn.functional下是函數實現。
conv1d
- 在nn下是一個類,一般繼承nn.module通過定義forward()函數計算其值
class Conv1d(_ConvNd):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
kernel_size = _single(kernel_size)
stride = _single(stride)
padding = _single(padding)
dilation = _single(dilation)
super(Conv1d, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
False, _single(0), groups, bias)
def forward(self, input):
return torch.nn.functional.conv1d(input, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
- 在nn.functional下直接傳入參數即可使用,其會直接返回一個torch.nn.functional的函數,和上面類中的forward()中的函數一致
def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
groups=1):
if input is not None and input.dim() != 3:
raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))
f = ConvNd(_single(stride), _single(padding), _single(dilation), False,
_single(0), groups, torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
- nn.Xxx不需要自己定義和管理參數weight;而nn.functional.xxx需要自己定義weight,每次調用的時候都需要手動傳入weight
同樣droptout用nn定義的話在訓練時生效,在eval()時無效。
損失函數Loss(交叉熵)
- nn庫
import torch
import torch.nn as nn
Loss = nn.BCELoss()
a = torch.ones(2,2)
b = torch.ones(2,2)
c = Loss(a,b)
- nn.functional庫
import torch
import torch.nn.functional as nn
a = torch.ones(2,2)
b = torch.ones(2,2)
c = nn.binary_cross_entropy(a,b)
c的結果都一樣爲0,即兩個分佈高度相似
總結一下,兩個庫都可以實現神經網絡的各層運算。其他包括卷積、池化、padding、激活(非線性層)、線性層、正則化層、其他損失函數Loss,兩者都可以實現
nn.functional.xxx是函數接口,而nn.Xxx是nn.functional.xxx的類封裝,並且nn.Xxx都繼承於一個共同祖先nn.Module。因此nn.Xxx除了具有nn.functional.xxx功能(通過類中的forward方法實現),內部附帶了nn.Module相關的屬性和方法,例如train(), eval(),load_state_dict, state_dict 等,可以自動管理各層的參數,同時還可以實現如Sequential()將多個運算層組合爲一個邏輯層。
參考
https://pytorch.org/docs/stable/nn.html#
https://pytorch.org/docs/stable/nn.functional.html#
https://www.zhihu.com/question/66782101