pytorch之打印網絡節點

1 輸入:

import torch

the_model = torch.load("model.pt")
print (the_model)

params=the_model.state_dict()
len(params)
for i, j in enumerate(params):
  print i, j

2 輸出: 

RNNModel(
  (drop): Dropout(p=0.2)
  (encoder): Embedding(33278, 200)
  (rnn): LSTM(200, 200, num_layers=2, dropout=0.2)
  (decoder): Linear(in_features=200, out_features=33278, bias=True)
)

0 encoder.weight
1 rnn.weight_ih_l0
2 rnn.weight_hh_l0
3 rnn.bias_ih_l0
4 rnn.bias_hh_l0
5 rnn.weight_ih_l1
6 rnn.weight_hh_l1
7 rnn.bias_ih_l1
8 rnn.bias_hh_l1
9 decoder.weight
10 decoder.bias

3 當然,使用model.named_parameters()也能打印出網絡參數和名稱,有興趣的可以試試啊。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章