1 Input:
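import torch

# Loads a whole pickled model, so the RNNModel class must be importable.
# PyTorch >= 2.6 defaults to weights_only=True, which rejects full models;
# weights_only=False (available since 1.13) restores the old behavior.
# Only use it on checkpoints you trust.
the_model = torch.load("model.pt", weights_only=False)
print(the_model)
params = the_model.state_dict()  # ordered mapping: name -> tensor
len(params)  # 11 entries here (shown only in an interactive session)
for i, name in enumerate(params):  # iterating a state_dict yields its keys
    print(i, name)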
2 Output:
RNNModel(
  (drop): Dropout(p=0.2)
  (encoder): Embedding(33278, 200)
  (rnn): LSTM(200, 200, num_layers=2, dropout=0.2)
  (decoder): Linear(in_features=200, out_features=33278, bias=True)
)
0 encoder.weight
1 rnn.weight_ih_l0
2 rnn.weight_hh_l0
3 rnn.bias_ih_l0
4 rnn.bias_hh_l0
5 rnn.weight_ih_l1
6 rnn.weight_hh_l1
7 rnn.bias_ih_l1
8 rnn.bias_hh_l1
9 decoder.weight
10 decoder.bias
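The LSTM entries follow nn.LSTM's naming scheme. For layer k:

- weight_ih_lk has shape (4*hidden_size, input size of that layer).
- weight_hh_lk has shape (4*hidden_size, hidden_size).
- Each bias has length 4*hidden_size.

The factor 4 stacks the input, forget, cell and output gates. The sketch below rebuilds a module with the same submodules as the printed RNNModel, so it does not need model.pt. It then prints each state_dict key with its tensor shape. The class name TinyRNNModel and its constructor arguments are assumptions inferred from the printed structure, not the original definition.

```python
import torch.nn as nn


class TinyRNNModel(nn.Module):
    """Hypothetical stand-in with the same submodules as the printed RNNModel."""

    def __init__(self, ntoken=33278, ninp=200, nhid=200, nlayers=2, dropout=0.2):
        super().__init__()
        self.drop = nn.Dropout(dropout)              # no tensors, so no state_dict keys
        self.encoder = nn.Embedding(ntoken, ninp)    # encoder.weight: (ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)       # decoder.weight / decoder.bias


model = TinyRNNModel()
# Map each state_dict key to its tensor shape, keeping registration order.
shapes = {name: tuple(t.shape) for name, t in model.state_dict().items()}
for i, (name, shape) in enumerate(shapes.items()):
    print(i, name, shape)
```

Dropout holds no tensors. That is why drop does not appear among the keys, even though it shows up in print(the_model).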
3 Alternatively, model.named_parameters() yields (name, parameter) pairs, so it can print both the parameter names and their values. Feel free to try it.
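A minimal sketch of that alternative follows, using a small hypothetical model rather than the RNNModel above. Any nn.Module behaves the same way.

named_parameters() covers only learnable parameters. state_dict() additionally includes registered buffers, such as BatchNorm running statistics. The model pairs a Linear layer with BatchNorm1d purely to make that difference visible.

```python
import torch.nn as nn

# Hypothetical small model: BatchNorm1d registers buffers as well as parameters.
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# named_parameters() yields (name, nn.Parameter) pairs: learnable tensors only.
param_names = []
for i, (name, param) in enumerate(model.named_parameters()):
    print(i, name, tuple(param.shape), param.requires_grad)
    param_names.append(name)

# state_dict() also contains buffers, e.g. BatchNorm's running statistics.
state_names = list(model.state_dict().keys())
extra = [n for n in state_names if n not in param_names]
print("only in state_dict:", extra)
```

For the RNNModel above, the two listings would give the same names, since none of its layers register buffers.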