PyTorch的nn.Linear()
是用於設置網絡中的全連接層的,需要注意的是全連接層的輸入與輸出都是二維張量,一般形狀爲[batch_size, size]
,不同於卷積層要求輸入輸出是四維張量。其用法與形參說明如下:
in_features
指的是輸入的二維張量的大小,即輸入的[batch_size, size]
中的size
。
out_features
指的是輸出的二維張量的大小,即輸出的二維張量的形狀爲[batch_size,output_size]
,當然,它也代表了該全連接層的神經元個數。
從輸入輸出的張量的shape角度來理解,相當於一個輸入爲[batch_size, in_features]
的張量變換成了[batch_size, out_features]
的輸出張量。
用法示例:
import torch as t
from torch import nn
# in_features由輸入張量的形狀決定,out_features則決定了輸出張量的形狀
connected_layer = nn.Linear(in_features = 64643, out_features = 1)
# 假定輸入的圖像形狀爲[64,64,3]
input = t.randn(1,64,64,3)
# 將四維張量轉換爲二維張量之後,才能作爲全連接層的輸入
input = input.view(1,64643)
print(input.shape)
output = connected_layer(input) # 調用全連接層
print(output.shape)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
這段代碼運行結果爲:
input shape is %s torch.Size([1, 12288])
output shape is %s torch.Size([1, 1])
- 1
- 2
</div>