Pytorch模型訓練相關函數記錄

1 .訓練模型時,對梯度進行截斷

 import torch as t
  _ = nn.utils.clip_grad_norm_(model.parameters(), clip)

2 .load預訓練好的模型
cpu->cpu或者gpu->gpu,直接

model.load_state_dict(t.load(param_file))  # load模型

gpu上訓練,在cpu上載入

model.load_state_dict(t.load(params_file, map_location='cpu'))

3.學習率衰減
每5個epoch衰減爲原來的0.1

optimizer_f = optim.SGD(f.parameters(), lr=lr)
schedulerF = optim.lr_scheduler.StepLR(optimizer_f, step_size=5, gamma=0.1)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章