【最佳實踐】pytorch模型權重的重置與重新賦值

重置爲原來的值:

def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

model = = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.ReLU(),
    nn.Linear(20, 3)
)

model.apply(weight_reset)

參考鏈接:How to re-set alll parameters in a network

重新賦值爲指定值:

with torch.no_grad():
    for name, param in model.named_parameters():
        if 'classifier.weight' in name:
            param.copy_(torch.randn(10, 10))

不推薦直接使用.data屬性賦值,因爲直接賦值會使得該操作無法被類感知,可能會造成某種隱含的bug。

參考鏈接:How to assign an arbitrary tensor to model’s parameter?

發佈了137 篇原創文章 · 獲贊 85 · 訪問量 12萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章