Linear transformation:
Example:
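import torch
import torch.nn as nn

input1 = torch.randn(128, 20)
input2 = torch.randn(128, 3, 20)  # the * in the middle can be any number of extra dimensions
input3 = torch.randn(128, 3, 4, 20)
m = nn.Linear(20, 30)  # in_features=20, out_features=30
output1 = m(input1)
output2 = m(input2)
output3 = m(input3)
# Only the last dimension changes (20 -> 30); the leading dimensions are kept
print(output1.size(), output2.size(), output3.size())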
# Output:
torch.Size([128, 30]) torch.Size([128, 3, 30]) torch.Size([128, 3, 4, 30])
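To make the shape rule above more concrete, here is a minimal sketch that applies y = x W^T + b by hand using the layer's own `weight` and `bias` and compares the result with the layer's output. The variable names `x`, `manual` and `matches` are illustrative only, and the comparison uses a tolerance because floating-point results can differ in the last bits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

m = nn.Linear(20, 30)
x = torch.randn(128, 3, 20)

# nn.Linear stores weight with shape (out_features, in_features)
# and bias with shape (out_features,)
print(m.weight.shape, m.bias.shape)

# Apply y = x @ W^T + b by hand; matmul broadcasts over the leading dimensions,
# so only the last dimension (20) is contracted
manual = x @ m.weight.t() + m.bias

# Compare with a tolerance instead of ==, since float results can differ slightly
matches = torch.allclose(m(x), manual, atol=1e-5)
print(manual.shape, matches)
```

Only the last dimension takes part in the matrix product, which is why any number of dimensions can sit in front of it.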
The * in the middle of the input shape can stand for any number of dimensions. Here is why:
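input2 = torch.randn(128, 3, 20)
m = nn.Linear(20, 30)
output2 = m(input2)
# Merge the first two dimensions: (128, 3, 20) -> (384, 20)
input3 = input2.reshape(128 * 3, 20)
output3 = m(input3)
# Element-wise comparison against the 3-D result, flattened the same way
print(output3 == output2.reshape(128 * 3, -1))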
# Output (from an older PyTorch version; recent versions return a torch.bool tensor of True values):
tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
...,
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], dtype=torch.uint8)
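As shown, the result is the same as collapsing all leading dimensions into one (their product), which turns the input into a 2-D matrix. nn.Linear() then applies its linear transformation, that is, the transformation of a fully connected layer, to each row.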
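The same check can be sketched for an input with more leading dimensions. The helper `linear_via_flatten` below is a hypothetical name introduced for illustration and is not part of PyTorch. It only demonstrates that flattening, applying the layer, and restoring the shape gives the same result as calling the layer directly; it makes no claim about how PyTorch implements this internally.

```python
import torch
import torch.nn as nn

def linear_via_flatten(layer, x):
    """Hypothetical helper: apply `layer` to a 2-D view of x, then restore the shape."""
    lead_shape = x.shape[:-1]            # e.g. (128, 3, 4)
    flat = x.reshape(-1, x.shape[-1])    # e.g. (128 * 3 * 4, 20)
    out = layer(flat)                    # e.g. (1536, 30)
    return out.reshape(*lead_shape, -1)  # e.g. (128, 3, 4, 30)

m = nn.Linear(20, 30)
x = torch.randn(128, 3, 4, 20)

direct = m(x)                            # layer applied to the 4-D input
via_flatten = linear_via_flatten(m, x)   # flatten -> linear -> reshape back

print(via_flatten.shape)
# A tolerance-based comparison is more robust than == for floating-point values
same = torch.allclose(direct, via_flatten, atol=1e-5)
print(same)
```

Using `torch.allclose` rather than `==` avoids false mismatches if the two code paths round differently.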