nn.Linear()

  PyTorch的nn.Linear()是用於設置網絡中的全連接層的需要注意的是全連接層的輸入與輸出都是二維張量,一般形狀爲[batch_size, size],不同於卷積層要求輸入輸出是四維張量。其用法與形參說明如下:
在這裏插入圖片描述
  in_features指的是輸入的二維張量的大小,即輸入的[batch_size, size]中的size
  out_features指的是輸出的二維張量的大小,即輸出的二維張量的形狀爲[batch_size,output_size],當然,它也代表了該全連接層的神經元個數
  從輸入輸出的張量的shape角度來理解,相當於一個輸入爲[batch_size, in_features]的張量變換成了[batch_size, out_features]的輸出張量。
用法示例:

import torch as t
from torch import nn

# in_features由輸入張量的形狀決定,out_features則決定了輸出張量的形狀
connected_layer = nn.Linear(in_features = 64643, out_features = 1)

# 假定輸入的圖像形狀爲[64,64,3]
input = t.randn(1,64,64,3)

# 將四維張量轉換爲二維張量之後,才能作爲全連接層的輸入
input = input.view(1,64643)
print(input.shape)
output = connected_layer(input) # 調用全連接層
print(output.shape)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

這段代碼運行結果爲:

input shape is %s torch.Size([1, 12288])
output shape is %s torch.Size([1, 1])
  • 1
  • 2
                                </div>
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章