pytorch 中網絡參數 weight bias 初始化方法

權重初始化對於訓練神經網絡至關重要,好的初始化權重可以有效的避免梯度消失等問題的發生。

在pytorch的使用過程中有幾種權重初始化的方法供大家參考。
注意:第一種方法不推薦。儘量使用後兩種方法。

# not recommend

def weights_init(m):

classname = m.__class__.__name__

if classname.find('Conv') != -1:

m.weight.data.normal_(0.0, 0.02)

elif classname.find('BatchNorm') != -1:

m.weight.data.normal_(1.0, 0.02)

m.bias.data.fill_(0)

# recommend

def initialize_weights(m):

if isinstance(m, nn.Conv2d):

m.weight.data.normal_(0, 0.02)

m.bias.data.zero_()

elif isinstance(m, nn.Linear):

m.weight.data.normal_(0, 0.02)

m.bias.data.zero_()

# recommend

def weights_init(m):

if isinstance(m, nn.Conv2d):

nn.init.xavier_normal_(m.weight.data)

nn.init.xavier_normal_(m.bias.data)

elif isinstance(m, nn.BatchNorm2d):

nn.init.constant_(m.weight,1)

nn.init.constant_(m.bias, 0)

elif isinstance(m, nn.BatchNorm1d):

nn.init.constant_(m.weight,1)

nn.init.constant_(m.bias, 0)

編寫好weights_init函數後,可以使用模型的apply方法對模型進行權重初始化。
 

net = Residual() # generate an instance network from the Net class

net.apply(weights_init) # apply weight init

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章