Freezing layers by name for fine-tuning in PyTorch

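In deep learning tasks, a large network trained on a small dataset overfits easily.
One way to reduce overfitting is to fine-tune a pretrained model.
In that case we do not need to train every parameter of the model: we can freeze some layers and train only the rest.
This post is a short introduction to, and record of, training with frozen layers.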

First, a look at model.named_parameters(), which yields a (name, parameter) pair for every parameter in the model. Straight to an example:

import torch
import torch.nn as nn
from ResNeSt.resnest.torch.resnest import resnest101, resnest200, resnest269

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        model = resnest101(pretrained=True)  # pretrained backbone (resnest101, matching the output below)
        model.fc = nn.Linear(2048, 102)      # replace the classifier head with a 102-class output
        self.resnet = model

    def forward(self, img):
        out = self.resnet(img)
        # print('out is {}'.format(out))
        return out

model = Net().cuda()
for name, param in model.named_parameters():  # list the parameters that can be optimized
    if param.requires_grad:
        print(name)

The code above loads the resnest101 model. Note that it is resnest101, not resnet101. For an introduction to ResNeSt, see "[The strongest ResNet variant] ResNeSt study notes: ResNeSt: Split-Attention Networks".

Iterating over the parameters prints the names of the layers that can be optimized. The output is as follows:

resnet.conv1.0.weight
resnet.conv1.1.weight
resnet.conv1.1.bias
resnet.conv1.3.weight
resnet.conv1.4.weight
resnet.conv1.4.bias
resnet.conv1.6.weight
resnet.bn1.weight
resnet.bn1.bias
resnet.layer1.0.conv1.weight
resnet.layer1.0.bn1.weight
resnet.layer1.0.bn1.bias
resnet.layer1.0.conv2.conv.weight
resnet.layer1.0.conv2.bn0.weight
resnet.layer1.0.conv2.bn0.bias
resnet.layer1.0.conv2.fc1.weight
resnet.layer1.0.conv2.fc1.bias
resnet.layer1.0.conv2.bn1.weight
resnet.layer1.0.conv2.bn1.bias
resnet.layer1.0.conv2.fc2.weight
resnet.layer1.0.conv2.fc2.bias
resnet.layer1.0.conv3.weight
resnet.layer1.0.bn3.weight
resnet.layer1.0.bn3.bias
resnet.layer1.0.downsample.1.weight
resnet.layer1.0.downsample.2.weight
resnet.layer1.0.downsample.2.bias
resnet.layer1.1.conv1.weight
resnet.layer1.1.bn1.weight
resnet.layer1.1.bn1.bias
resnet.layer1.1.conv2.conv.weight
resnet.layer1.1.conv2.bn0.weight
resnet.layer1.1.conv2.bn0.bias
resnet.layer1.1.conv2.fc1.weight
resnet.layer1.1.conv2.fc1.bias
resnet.layer1.1.conv2.bn1.weight
resnet.layer1.1.conv2.bn1.bias
resnet.layer1.1.conv2.fc2.weight
resnet.layer1.1.conv2.fc2.bias
resnet.layer1.1.conv3.weight
resnet.layer1.1.bn3.weight
resnet.layer1.1.bn3.bias
resnet.layer1.2.conv1.weight
resnet.layer1.2.bn1.weight
resnet.layer1.2.bn1.bias
resnet.layer1.2.conv2.conv.weight
resnet.layer1.2.conv2.bn0.weight
resnet.layer1.2.conv2.bn0.bias
resnet.layer1.2.conv2.fc1.weight
resnet.layer1.2.conv2.fc1.bias
resnet.layer1.2.conv2.bn1.weight
resnet.layer1.2.conv2.bn1.bias
resnet.layer1.2.conv2.fc2.weight
resnet.layer1.2.conv2.fc2.bias
resnet.layer1.2.conv3.weight
resnet.layer1.2.bn3.weight
resnet.layer1.2.bn3.bias
resnet.layer2.0.conv1.weight
resnet.layer2.0.bn1.weight
resnet.layer2.0.bn1.bias
resnet.layer2.0.conv2.conv.weight
resnet.layer2.0.conv2.bn0.weight
resnet.layer2.0.conv2.bn0.bias
resnet.layer2.0.conv2.fc1.weight
resnet.layer2.0.conv2.fc1.bias
resnet.layer2.0.conv2.bn1.weight
resnet.layer2.0.conv2.bn1.bias
resnet.layer2.0.conv2.fc2.weight
resnet.layer2.0.conv2.fc2.bias
resnet.layer2.0.conv3.weight
resnet.layer2.0.bn3.weight
resnet.layer2.0.bn3.bias
resnet.layer2.0.downsample.1.weight
resnet.layer2.0.downsample.2.weight
resnet.layer2.0.downsample.2.bias
resnet.layer2.1.conv1.weight
resnet.layer2.1.bn1.weight
resnet.layer2.1.bn1.bias
resnet.layer2.1.conv2.conv.weight
resnet.layer2.1.conv2.bn0.weight
resnet.layer2.1.conv2.bn0.bias
resnet.layer2.1.conv2.fc1.weight
resnet.layer2.1.conv2.fc1.bias
resnet.layer2.1.conv2.bn1.weight
resnet.layer2.1.conv2.bn1.bias
resnet.layer2.1.conv2.fc2.weight
resnet.layer2.1.conv2.fc2.bias
resnet.layer2.1.conv3.weight
resnet.layer2.1.bn3.weight
resnet.layer2.1.bn3.bias
resnet.layer2.2.conv1.weight
resnet.layer2.2.bn1.weight
resnet.layer2.2.bn1.bias
resnet.layer2.2.conv2.conv.weight
resnet.layer2.2.conv2.bn0.weight
resnet.layer2.2.conv2.bn0.bias
resnet.layer2.2.conv2.fc1.weight
resnet.layer2.2.conv2.fc1.bias
resnet.layer2.2.conv2.bn1.weight
resnet.layer2.2.conv2.bn1.bias
resnet.layer2.2.conv2.fc2.weight
resnet.layer2.2.conv2.fc2.bias
resnet.layer2.2.conv3.weight
resnet.layer2.2.bn3.weight
resnet.layer2.2.bn3.bias
resnet.layer2.3.conv1.weight
resnet.layer2.3.bn1.weight
resnet.layer2.3.bn1.bias
resnet.layer2.3.conv2.conv.weight
resnet.layer2.3.conv2.bn0.weight
resnet.layer2.3.conv2.bn0.bias
resnet.layer2.3.conv2.fc1.weight
resnet.layer2.3.conv2.fc1.bias
resnet.layer2.3.conv2.bn1.weight
resnet.layer2.3.conv2.bn1.bias
resnet.layer2.3.conv2.fc2.weight
resnet.layer2.3.conv2.fc2.bias
resnet.layer2.3.conv3.weight
resnet.layer2.3.bn3.weight
resnet.layer2.3.bn3.bias
resnet.layer3.0.conv1.weight
resnet.layer3.0.bn1.weight
resnet.layer3.0.bn1.bias
resnet.layer3.0.conv2.conv.weight
resnet.layer3.0.conv2.bn0.weight
resnet.layer3.0.conv2.bn0.bias
resnet.layer3.0.conv2.fc1.weight
resnet.layer3.0.conv2.fc1.bias
resnet.layer3.0.conv2.bn1.weight
resnet.layer3.0.conv2.bn1.bias
resnet.layer3.0.conv2.fc2.weight
resnet.layer3.0.conv2.fc2.bias
resnet.layer3.0.conv3.weight
resnet.layer3.0.bn3.weight
resnet.layer3.0.bn3.bias
resnet.layer3.0.downsample.1.weight
resnet.layer3.0.downsample.2.weight
resnet.layer3.0.downsample.2.bias
resnet.layer3.1.conv1.weight
resnet.layer3.1.bn1.weight
resnet.layer3.1.bn1.bias
resnet.layer3.1.conv2.conv.weight
resnet.layer3.1.conv2.bn0.weight
resnet.layer3.1.conv2.bn0.bias
resnet.layer3.1.conv2.fc1.weight
resnet.layer3.1.conv2.fc1.bias
resnet.layer3.1.conv2.bn1.weight
resnet.layer3.1.conv2.bn1.bias
resnet.layer3.1.conv2.fc2.weight
resnet.layer3.1.conv2.fc2.bias
resnet.layer3.1.conv3.weight
resnet.layer3.1.bn3.weight
resnet.layer3.1.bn3.bias
resnet.layer3.2.conv1.weight
resnet.layer3.2.bn1.weight
resnet.layer3.2.bn1.bias
resnet.layer3.2.conv2.conv.weight
resnet.layer3.2.conv2.bn0.weight
resnet.layer3.2.conv2.bn0.bias
resnet.layer3.2.conv2.fc1.weight
resnet.layer3.2.conv2.fc1.bias
resnet.layer3.2.conv2.bn1.weight
resnet.layer3.2.conv2.bn1.bias
resnet.layer3.2.conv2.fc2.weight
resnet.layer3.2.conv2.fc2.bias
resnet.layer3.2.conv3.weight
resnet.layer3.2.bn3.weight
resnet.layer3.2.bn3.bias
resnet.layer3.3.conv1.weight
resnet.layer3.3.bn1.weight
resnet.layer3.3.bn1.bias
resnet.layer3.3.conv2.conv.weight
resnet.layer3.3.conv2.bn0.weight
resnet.layer3.3.conv2.bn0.bias
resnet.layer3.3.conv2.fc1.weight
resnet.layer3.3.conv2.fc1.bias
resnet.layer3.3.conv2.bn1.weight
resnet.layer3.3.conv2.bn1.bias
resnet.layer3.3.conv2.fc2.weight
resnet.layer3.3.conv2.fc2.bias
resnet.layer3.3.conv3.weight
resnet.layer3.3.bn3.weight
resnet.layer3.3.bn3.bias
resnet.layer3.4.conv1.weight
resnet.layer3.4.bn1.weight
resnet.layer3.4.bn1.bias
resnet.layer3.4.conv2.conv.weight
resnet.layer3.4.conv2.bn0.weight
resnet.layer3.4.conv2.bn0.bias
resnet.layer3.4.conv2.fc1.weight
resnet.layer3.4.conv2.fc1.bias
resnet.layer3.4.conv2.bn1.weight
resnet.layer3.4.conv2.bn1.bias
resnet.layer3.4.conv2.fc2.weight
resnet.layer3.4.conv2.fc2.bias
resnet.layer3.4.conv3.weight
resnet.layer3.4.bn3.weight
resnet.layer3.4.bn3.bias
resnet.layer3.5.conv1.weight
resnet.layer3.5.bn1.weight
resnet.layer3.5.bn1.bias
resnet.layer3.5.conv2.conv.weight
resnet.layer3.5.conv2.bn0.weight
resnet.layer3.5.conv2.bn0.bias
resnet.layer3.5.conv2.fc1.weight
resnet.layer3.5.conv2.fc1.bias
resnet.layer3.5.conv2.bn1.weight
resnet.layer3.5.conv2.bn1.bias
resnet.layer3.5.conv2.fc2.weight
resnet.layer3.5.conv2.fc2.bias
resnet.layer3.5.conv3.weight
resnet.layer3.5.bn3.weight
resnet.layer3.5.bn3.bias
resnet.layer3.6.conv1.weight
resnet.layer3.6.bn1.weight
resnet.layer3.6.bn1.bias
resnet.layer3.6.conv2.conv.weight
resnet.layer3.6.conv2.bn0.weight
resnet.layer3.6.conv2.bn0.bias
resnet.layer3.6.conv2.fc1.weight
resnet.layer3.6.conv2.fc1.bias
resnet.layer3.6.conv2.bn1.weight
resnet.layer3.6.conv2.bn1.bias
resnet.layer3.6.conv2.fc2.weight
resnet.layer3.6.conv2.fc2.bias
resnet.layer3.6.conv3.weight
resnet.layer3.6.bn3.weight
resnet.layer3.6.bn3.bias
resnet.layer3.7.conv1.weight
resnet.layer3.7.bn1.weight
resnet.layer3.7.bn1.bias
resnet.layer3.7.conv2.conv.weight
resnet.layer3.7.conv2.bn0.weight
resnet.layer3.7.conv2.bn0.bias
resnet.layer3.7.conv2.fc1.weight
resnet.layer3.7.conv2.fc1.bias
resnet.layer3.7.conv2.bn1.weight
resnet.layer3.7.conv2.bn1.bias
resnet.layer3.7.conv2.fc2.weight
resnet.layer3.7.conv2.fc2.bias
resnet.layer3.7.conv3.weight
resnet.layer3.7.bn3.weight
resnet.layer3.7.bn3.bias
resnet.layer3.8.conv1.weight
resnet.layer3.8.bn1.weight
resnet.layer3.8.bn1.bias
resnet.layer3.8.conv2.conv.weight
resnet.layer3.8.conv2.bn0.weight
resnet.layer3.8.conv2.bn0.bias
resnet.layer3.8.conv2.fc1.weight
resnet.layer3.8.conv2.fc1.bias
resnet.layer3.8.conv2.bn1.weight
resnet.layer3.8.conv2.bn1.bias
resnet.layer3.8.conv2.fc2.weight
resnet.layer3.8.conv2.fc2.bias
resnet.layer3.8.conv3.weight
resnet.layer3.8.bn3.weight
resnet.layer3.8.bn3.bias
resnet.layer3.9.conv1.weight
resnet.layer3.9.bn1.weight
resnet.layer3.9.bn1.bias
resnet.layer3.9.conv2.conv.weight
resnet.layer3.9.conv2.bn0.weight
resnet.layer3.9.conv2.bn0.bias
resnet.layer3.9.conv2.fc1.weight
resnet.layer3.9.conv2.fc1.bias
resnet.layer3.9.conv2.bn1.weight
resnet.layer3.9.conv2.bn1.bias
resnet.layer3.9.conv2.fc2.weight
resnet.layer3.9.conv2.fc2.bias
resnet.layer3.9.conv3.weight
resnet.layer3.9.bn3.weight
resnet.layer3.9.bn3.bias
resnet.layer3.10.conv1.weight
resnet.layer3.10.bn1.weight
resnet.layer3.10.bn1.bias
resnet.layer3.10.conv2.conv.weight
resnet.layer3.10.conv2.bn0.weight
resnet.layer3.10.conv2.bn0.bias
resnet.layer3.10.conv2.fc1.weight
resnet.layer3.10.conv2.fc1.bias
resnet.layer3.10.conv2.bn1.weight
resnet.layer3.10.conv2.bn1.bias
resnet.layer3.10.conv2.fc2.weight
resnet.layer3.10.conv2.fc2.bias
resnet.layer3.10.conv3.weight
resnet.layer3.10.bn3.weight
resnet.layer3.10.bn3.bias
resnet.layer3.11.conv1.weight
resnet.layer3.11.bn1.weight
resnet.layer3.11.bn1.bias
resnet.layer3.11.conv2.conv.weight
resnet.layer3.11.conv2.bn0.weight
resnet.layer3.11.conv2.bn0.bias
resnet.layer3.11.conv2.fc1.weight
resnet.layer3.11.conv2.fc1.bias
resnet.layer3.11.conv2.bn1.weight
resnet.layer3.11.conv2.bn1.bias
resnet.layer3.11.conv2.fc2.weight
resnet.layer3.11.conv2.fc2.bias
resnet.layer3.11.conv3.weight
resnet.layer3.11.bn3.weight
resnet.layer3.11.bn3.bias
resnet.layer3.12.conv1.weight
resnet.layer3.12.bn1.weight
resnet.layer3.12.bn1.bias
resnet.layer3.12.conv2.conv.weight
resnet.layer3.12.conv2.bn0.weight
resnet.layer3.12.conv2.bn0.bias
resnet.layer3.12.conv2.fc1.weight
resnet.layer3.12.conv2.fc1.bias
resnet.layer3.12.conv2.bn1.weight
resnet.layer3.12.conv2.bn1.bias
resnet.layer3.12.conv2.fc2.weight
resnet.layer3.12.conv2.fc2.bias
resnet.layer3.12.conv3.weight
resnet.layer3.12.bn3.weight
resnet.layer3.12.bn3.bias
resnet.layer3.13.conv1.weight
resnet.layer3.13.bn1.weight
resnet.layer3.13.bn1.bias
resnet.layer3.13.conv2.conv.weight
resnet.layer3.13.conv2.bn0.weight
resnet.layer3.13.conv2.bn0.bias
resnet.layer3.13.conv2.fc1.weight
resnet.layer3.13.conv2.fc1.bias
resnet.layer3.13.conv2.bn1.weight
resnet.layer3.13.conv2.bn1.bias
resnet.layer3.13.conv2.fc2.weight
resnet.layer3.13.conv2.fc2.bias
resnet.layer3.13.conv3.weight
resnet.layer3.13.bn3.weight
resnet.layer3.13.bn3.bias
resnet.layer3.14.conv1.weight
resnet.layer3.14.bn1.weight
resnet.layer3.14.bn1.bias
resnet.layer3.14.conv2.conv.weight
resnet.layer3.14.conv2.bn0.weight
resnet.layer3.14.conv2.bn0.bias
resnet.layer3.14.conv2.fc1.weight
resnet.layer3.14.conv2.fc1.bias
resnet.layer3.14.conv2.bn1.weight
resnet.layer3.14.conv2.bn1.bias
resnet.layer3.14.conv2.fc2.weight
resnet.layer3.14.conv2.fc2.bias
resnet.layer3.14.conv3.weight
resnet.layer3.14.bn3.weight
resnet.layer3.14.bn3.bias
resnet.layer3.15.conv1.weight
resnet.layer3.15.bn1.weight
resnet.layer3.15.bn1.bias
resnet.layer3.15.conv2.conv.weight
resnet.layer3.15.conv2.bn0.weight
resnet.layer3.15.conv2.bn0.bias
resnet.layer3.15.conv2.fc1.weight
resnet.layer3.15.conv2.fc1.bias
resnet.layer3.15.conv2.bn1.weight
resnet.layer3.15.conv2.bn1.bias
resnet.layer3.15.conv2.fc2.weight
resnet.layer3.15.conv2.fc2.bias
resnet.layer3.15.conv3.weight
resnet.layer3.15.bn3.weight
resnet.layer3.15.bn3.bias
resnet.layer3.16.conv1.weight
resnet.layer3.16.bn1.weight
resnet.layer3.16.bn1.bias
resnet.layer3.16.conv2.conv.weight
resnet.layer3.16.conv2.bn0.weight
resnet.layer3.16.conv2.bn0.bias
resnet.layer3.16.conv2.fc1.weight
resnet.layer3.16.conv2.fc1.bias
resnet.layer3.16.conv2.bn1.weight
resnet.layer3.16.conv2.bn1.bias
resnet.layer3.16.conv2.fc2.weight
resnet.layer3.16.conv2.fc2.bias
resnet.layer3.16.conv3.weight
resnet.layer3.16.bn3.weight
resnet.layer3.16.bn3.bias
resnet.layer3.17.conv1.weight
resnet.layer3.17.bn1.weight
resnet.layer3.17.bn1.bias
resnet.layer3.17.conv2.conv.weight
resnet.layer3.17.conv2.bn0.weight
resnet.layer3.17.conv2.bn0.bias
resnet.layer3.17.conv2.fc1.weight
resnet.layer3.17.conv2.fc1.bias
resnet.layer3.17.conv2.bn1.weight
resnet.layer3.17.conv2.bn1.bias
resnet.layer3.17.conv2.fc2.weight
resnet.layer3.17.conv2.fc2.bias
resnet.layer3.17.conv3.weight
resnet.layer3.17.bn3.weight
resnet.layer3.17.bn3.bias
resnet.layer3.18.conv1.weight
resnet.layer3.18.bn1.weight
resnet.layer3.18.bn1.bias
resnet.layer3.18.conv2.conv.weight
resnet.layer3.18.conv2.bn0.weight
resnet.layer3.18.conv2.bn0.bias
resnet.layer3.18.conv2.fc1.weight
resnet.layer3.18.conv2.fc1.bias
resnet.layer3.18.conv2.bn1.weight
resnet.layer3.18.conv2.bn1.bias
resnet.layer3.18.conv2.fc2.weight
resnet.layer3.18.conv2.fc2.bias
resnet.layer3.18.conv3.weight
resnet.layer3.18.bn3.weight
resnet.layer3.18.bn3.bias
resnet.layer3.19.conv1.weight
resnet.layer3.19.bn1.weight
resnet.layer3.19.bn1.bias
resnet.layer3.19.conv2.conv.weight
resnet.layer3.19.conv2.bn0.weight
resnet.layer3.19.conv2.bn0.bias
resnet.layer3.19.conv2.fc1.weight
resnet.layer3.19.conv2.fc1.bias
resnet.layer3.19.conv2.bn1.weight
resnet.layer3.19.conv2.bn1.bias
resnet.layer3.19.conv2.fc2.weight
resnet.layer3.19.conv2.fc2.bias
resnet.layer3.19.conv3.weight
resnet.layer3.19.bn3.weight
resnet.layer3.19.bn3.bias
resnet.layer3.20.conv1.weight
resnet.layer3.20.bn1.weight
resnet.layer3.20.bn1.bias
resnet.layer3.20.conv2.conv.weight
resnet.layer3.20.conv2.bn0.weight
resnet.layer3.20.conv2.bn0.bias
resnet.layer3.20.conv2.fc1.weight
resnet.layer3.20.conv2.fc1.bias
resnet.layer3.20.conv2.bn1.weight
resnet.layer3.20.conv2.bn1.bias
resnet.layer3.20.conv2.fc2.weight
resnet.layer3.20.conv2.fc2.bias
resnet.layer3.20.conv3.weight
resnet.layer3.20.bn3.weight
resnet.layer3.20.bn3.bias
resnet.layer3.21.conv1.weight
resnet.layer3.21.bn1.weight
resnet.layer3.21.bn1.bias
resnet.layer3.21.conv2.conv.weight
resnet.layer3.21.conv2.bn0.weight
resnet.layer3.21.conv2.bn0.bias
resnet.layer3.21.conv2.fc1.weight
resnet.layer3.21.conv2.fc1.bias
resnet.layer3.21.conv2.bn1.weight
resnet.layer3.21.conv2.bn1.bias
resnet.layer3.21.conv2.fc2.weight
resnet.layer3.21.conv2.fc2.bias
resnet.layer3.21.conv3.weight
resnet.layer3.21.bn3.weight
resnet.layer3.21.bn3.bias
resnet.layer3.22.conv1.weight
resnet.layer3.22.bn1.weight
resnet.layer3.22.bn1.bias
resnet.layer3.22.conv2.conv.weight
resnet.layer3.22.conv2.bn0.weight
resnet.layer3.22.conv2.bn0.bias
resnet.layer3.22.conv2.fc1.weight
resnet.layer3.22.conv2.fc1.bias
resnet.layer3.22.conv2.bn1.weight
resnet.layer3.22.conv2.bn1.bias
resnet.layer3.22.conv2.fc2.weight
resnet.layer3.22.conv2.fc2.bias
resnet.layer3.22.conv3.weight
resnet.layer3.22.bn3.weight
resnet.layer3.22.bn3.bias
resnet.layer4.0.conv1.weight
resnet.layer4.0.bn1.weight
resnet.layer4.0.bn1.bias
resnet.layer4.0.conv2.conv.weight
resnet.layer4.0.conv2.bn0.weight
resnet.layer4.0.conv2.bn0.bias
resnet.layer4.0.conv2.fc1.weight
resnet.layer4.0.conv2.fc1.bias
resnet.layer4.0.conv2.bn1.weight
resnet.layer4.0.conv2.bn1.bias
resnet.layer4.0.conv2.fc2.weight
resnet.layer4.0.conv2.fc2.bias
resnet.layer4.0.conv3.weight
resnet.layer4.0.bn3.weight
resnet.layer4.0.bn3.bias
resnet.layer4.0.downsample.1.weight
resnet.layer4.0.downsample.2.weight
resnet.layer4.0.downsample.2.bias
resnet.layer4.1.conv1.weight
resnet.layer4.1.bn1.weight
resnet.layer4.1.bn1.bias
resnet.layer4.1.conv2.conv.weight
resnet.layer4.1.conv2.bn0.weight
resnet.layer4.1.conv2.bn0.bias
resnet.layer4.1.conv2.fc1.weight
resnet.layer4.1.conv2.fc1.bias
resnet.layer4.1.conv2.bn1.weight
resnet.layer4.1.conv2.bn1.bias
resnet.layer4.1.conv2.fc2.weight
resnet.layer4.1.conv2.fc2.bias
resnet.layer4.1.conv3.weight
resnet.layer4.1.bn3.weight
resnet.layer4.1.bn3.bias
resnet.layer4.2.conv1.weight
resnet.layer4.2.bn1.weight
resnet.layer4.2.bn1.bias
resnet.layer4.2.conv2.conv.weight
resnet.layer4.2.conv2.bn0.weight
resnet.layer4.2.conv2.bn0.bias
resnet.layer4.2.conv2.fc1.weight
resnet.layer4.2.conv2.fc1.bias
resnet.layer4.2.conv2.bn1.weight
resnet.layer4.2.conv2.bn1.bias
resnet.layer4.2.conv2.fc2.weight
resnet.layer4.2.conv2.fc2.bias
resnet.layer4.2.conv3.weight
resnet.layer4.2.bn3.weight
resnet.layer4.2.bn3.bias
resnet.fc.weight
resnet.fc.bias
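
A listing this long is easier to read once it is grouped by block. The following is a minimal sketch, not part of the original code: `count_by_block` is a hypothetical helper that counts parameter names by their first two dot-separated components, and `sample_names` is just a handful of names copied from the output above so the snippet runs without ResNeSt or a GPU.

```python
from collections import Counter

def count_by_block(names, depth=2):
    """Count parameter names by their first `depth` dot-separated components."""
    counts = Counter()
    for name in names:
        prefix = ".".join(name.split(".")[:depth])
        counts[prefix] += 1
    return counts

# A few names copied from the listing above (illustrative subset only).
sample_names = [
    "resnet.conv1.0.weight",
    "resnet.bn1.weight",
    "resnet.bn1.bias",
    "resnet.layer1.0.conv1.weight",
    "resnet.layer4.2.bn3.weight",
    "resnet.layer4.2.bn3.bias",
    "resnet.fc.weight",
    "resnet.fc.bias",
]

block_counts = count_by_block(sample_names)
for prefix, n in block_counts.items():
    print(prefix, n)
```

With a real model, the same helper could be fed `(name for name, _ in model.named_parameters())`. The prefixes it prints (`resnet.layer1` through `resnet.layer4`, `resnet.fc`, and so on) are the strings one would match against when choosing which layers to freeze.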

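However, if all of these parameters take part in training, the model overfits easily, and GPU memory cannot take it either; after all, the model really is large.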

Below I train only layer4 of resnest101 and freeze all the other layers, as follows:

model = Net().cuda()
for name, param in model.named_parameters():  # freeze everything except layer4
    # if param.requires_grad:
    if 'layer4' in name:
        print(name)  # parameters that stay trainable
        continue
    param.requires_grad = False

model = nn.DataParallel(model).cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 0.01)
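
One thing to watch in the loop above: the names `resnet.fc.weight` and `resnet.fc.bias` do not contain `layer4`, so the freshly initialized 102-class head is frozen as well. The sketch below is one possible variation, not the original method. It assumes a hypothetical `freeze_except` helper and a tiny stand-in model, `TinyNet`, that only mimics the attribute names, so it runs on CPU without ResNeSt. It keeps both `layer4` and `fc` trainable, then checks that a single SGD step leaves the frozen weights untouched.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in that only mimics ResNet-style attribute names."""
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)
        self.fc = nn.Linear(8, 3)

    def forward(self, x):
        x = torch.relu(self.layer3(x))
        x = torch.relu(self.layer4(x))
        return self.fc(x)

def freeze_except(model, keep_keywords):
    """Leave requires_grad=True only on parameters whose name contains a keyword."""
    trainable = []
    for name, param in model.named_parameters():
        param.requires_grad = any(k in name for k in keep_keywords)
        if param.requires_grad:
            trainable.append(name)
    return trainable

torch.manual_seed(0)
model = TinyNet()
trainable_names = freeze_except(model, keep_keywords=("layer4", "fc"))
print(trainable_names)

# Snapshots taken before the update, to compare against afterwards.
frozen_before = model.layer3.weight.detach().clone()
head_before = model.fc.bias.detach().clone()

# Same filtering idea as in the post: the optimizer only sees trainable parameters.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)
criterion = nn.CrossEntropyLoss()

loss = criterion(model(torch.randn(4, 8)), torch.tensor([0, 1, 2, 0]))
optimizer.zero_grad()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(model.layer3.weight, frozen_before)
head_changed = not torch.equal(model.fc.bias, head_before)
print(frozen_unchanged, head_changed)
```

Whether the head should stay trainable depends on the task. If only `layer4` is meant to be trained, passing `keep_keywords=("layer4",)` reproduces the behaviour of the loop in this post.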

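The following line defines the optimizer and applies gradient updates only to the filtered parameters (those whose requires_grad is still True), with the learning rate set to 0.01.

optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 0.01)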
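
A related caveat, offered as a side note rather than as part of the original recipe: setting requires_grad to False stops gradient updates, but a BatchNorm layer in training mode still updates its running statistics on every forward pass. Assuming the `bn*` modules in the listing are standard BatchNorm layers, the effect can be seen on a single stand-alone layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
for p in bn.parameters():
    p.requires_grad = False  # "frozen" in the sense used in this post

# A batch whose mean is far from the initial running_mean of zero.
x = torch.randn(16, 4) + 5.0

bn.train()
bn(x)  # training-mode forward pass: running statistics are updated
mean_after_train_forward = bn.running_mean.clone()

bn.eval()
bn(x)  # eval-mode forward pass: stored statistics are used and left alone
mean_after_eval_forward = bn.running_mean.clone()

print(mean_after_train_forward)
print(mean_after_eval_forward)
```

If the frozen layers should keep their pretrained statistics exactly, one option is to call `.eval()` on those submodules. Note that calling `model.train()` puts every submodule back into training mode, so the call would have to be repeated afterwards.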