【PyTorch】利用 requires_grad 凍結部分網絡參數

代碼示例:

import torch
import torch.nn as nn


class a1(nn.Module):
    """Base module holding a single linear layer ``l1`` (3 features in, 2 out).

    This class defines no ``forward`` method of its own.
    """

    def __init__(self):
        super().__init__()
        # First layer of the stack: 3 input features -> 2 output features.
        self.l1 = nn.Linear(in_features=3, out_features=2)


class aa(a1):
    """Three stacked linear layers: inherited ``l1`` (3->2), ``l2`` (2->2), ``l3`` (2->1)."""

    def __init__(self):
        super().__init__()
        # Two plain string attributes alongside the layers.
        self.a = ''
        self.b = 'c'
        # Layers are created in the order l2, l3 so the parameter order is
        # l1, l2, l3 when iterating over the module.
        self.l2 = nn.Linear(in_features=2, out_features=2)
        self.l3 = nn.Linear(in_features=2, out_features=1)

    def forward(self, x):
        """Feed ``x`` through ``l1``, ``l2`` and ``l3`` in turn and return the result."""
        return self.l3(self.l2(self.l1(x)))


# Demo: freeze layers l2 and l3 of the model, run one Adam step, and print the
# parameters before and after so the effect of freezing is visible.
a = aa()
freeze_layers = ['l2', 'l3']
opt_param = []
# Walk the direct sub-modules through the public named_children() API rather
# than reaching into the private _modules dict.
for name, module in a.named_children():
    if name in freeze_layers:
        # Frozen layer: turn off gradient tracking for each of its parameters.
        for p in module.parameters():
            p.requires_grad = False
    else:
        # Trainable layer: only these parameters are handed to the optimizer.
        opt_param.extend(module.parameters())

print('original parameters:\n', list(a.parameters()))

x = torch.tensor([[1, 2, 3], [4, 5, 6], [-1, -2, -3], [-2, -4, -5]], dtype=torch.float)
y = torch.tensor([1, 1, -1, -1], dtype=torch.float)
y = y.view(4, -1)  # reshape the targets from (4,) to (4, 1)
y_ = a(x)
print('y_', y_)
celoss = nn.MSELoss()  # NOTE: despite the variable name, this is mean-squared error
# Loss modules are called as (input, target): prediction first, target second.
loss = celoss(y_, y)
opt = torch.optim.Adam(opt_param)
loss.backward()
opt.step()

print('new parameters:\n', list(a.parameters()))

輸出:

original parameters:
 [Parameter containing:
tensor([[-0.0965, -0.3446,  0.0866],
        [-0.1677, -0.2664,  0.5007]], requires_grad=True), Parameter containing:
tensor([ 0.2607, -0.3101], requires_grad=True), Parameter containing:
tensor([[0.5237, 0.1328],
        [0.6994, 0.1014]]), Parameter containing:
tensor([ 0.2662, -0.1830]), Parameter containing:
tensor([[-0.0303,  0.1869]]), Parameter containing:
tensor([0.1692])]

y_  tensor([[ 0.1037],
        [-0.0155],
        [ 0.2007],
        [ 0.2665]], grad_fn=<AddmmBackward>)
new parameters:
 [Parameter containing:
tensor([[-0.0955, -0.3436,  0.0876],
        [-0.1667, -0.2654,  0.5017]], requires_grad=True), Parameter containing:
tensor([ 0.2597, -0.3111], requires_grad=True), Parameter containing:
tensor([[0.5237, 0.1328],
        [0.6994, 0.1014]]), Parameter containing:
tensor([ 0.2662, -0.1830]), Parameter containing:
tensor([[-0.0303,  0.1869]]), Parameter containing:
tensor([0.1692])]

 

可以看到，更新前後 l2 和 l3 兩層的參數完全沒有變化（已被凍結），只有 l1 的參數被優化器更新了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章