PyTorch : nn.Linear() 詳解

線性轉換:
在這裏插入圖片描述
舉例:

input1 = torch.randn(128, 20)
input2 = torch.randn(128, 3, 20) #中間 * 可以添加任意維度
input3 = torch.randn(128, 3, 4, 20) 
m = nn.Linear(20, 30)
output1 = m(input1)
output2 = m(input2)
output3 = m(input3)
print(output1.size(), output2.size(), output3.size())
#
torch.Size([128, 30]) torch.Size([128, 3, 30]) torch.Size([128, 3, 4, 30])

中間 * 可以是任意維度,原理解釋:

input2 = torch.randn(128, 3, 20)

m = nn.Linear(20, 30)
output2 = m(input2)

input3 = input2.reshape(128 * 3, 20)
output3 = m(input3)

print(output3 == output2.reshape(128 * 3, -1))
#
tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], dtype=torch.uint8)

可見,將所有前面的維度相乘變爲了二維矩陣,nn.Linear() 線性變換,也就是全連接層的變換。

PyTorch官方文檔

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章