動機
想找到模型中所有的batchnormal 層並固定梯度,但是找了好久沒有合適的方法,現在記錄在下面
注意:令require_grad=False 無效
以下針對模型在訓練的模式下,測試的話就沒必要了,直接 model.eval() 即可
方法一
model.train()
for m in model.modules():
if isinstance(m,nn.BatchNorm2d):
m.eval()
方法二
def fix_bn(m):
classname = m.__class__.__name__
if classname.find('BatchNorm') != -1:
m.eval()
model.train()
model.apply(fix_bn)