About the autograd mechanism and the detach function

First, the definitions of leaf nodes and non-leaf nodes:

A leaf node is a tensor whose is_leaf attribute is True and whose grad_fn is None. There are two cases in which a tensor is a leaf node:

Case 1: tensors created directly by the user (i.e. not produced by an operation):

a = torch.rand(5, 5, requires_grad=False)
b = torch.rand(5, 5, requires_grad=False)
c = torch.rand(5, 5, requires_grad=True)

print(a.is_leaf, b.is_leaf, c.is_leaf)

out:True True True

Here a, b and c are all leaf nodes. As you can see, any tensor created by the user is treated as a leaf node, whether or not its requires_grad is True.
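
We can confirm this by looking at grad_fn as well (a minimal sketch reusing a, b and c from above):

print(a.grad_fn, b.grad_fn, c.grad_fn)  # None None None: user-created tensors have no grad_fn
print(c.requires_grad, c.is_leaf)       # True True: requires_grad=True, yet still a leaf

The parameters of an nn.Module are created by the module itself rather than by an operation, so they are leaf nodes too, as the following example shows: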

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(20, 100)
        self.conv1 = nn.Conv1d(3, 20, 1)
        self.linear2 = nn.Linear(100, 3)

    def forward(self, x):
        x = self.linear1(x)
        x = self.conv1(x)
        x = self.linear2(x)
        return x


net = Net()
# .named_parameters() lets us look at the gradients and related information of each layer of the network
for name, param in net.named_parameters():
    print(param.is_leaf)


loss = torch.sum(net(torch.rand(3, 3, 20)))  # input of shape (3, 3, 20): linear1 acts on the last dim, conv1 treats dim 1 as channels
loss.backward()
print(loss)

for name, param in net.named_parameters():
    print(param.grad)

Output:

True
True
True
True
True
True
tensor(17.9995, grad_fn=<SumBackward0>)
tensor([[ 0.0022,  0.0038,  0.0032,  ...,  0.0040,  0.0043,  0.0028],
        [-0.2095, -0.3591, -0.3042,  ..., -0.3819, -0.4080, -0.2680],
        [-0.2490, -0.4268, -0.3616,  ..., -0.4540, -0.4849, -0.3185],
        ...,
        [-0.6358, -1.0897, -0.9233,  ..., -1.1591, -1.2382, -0.8132],
        [ 0.7041,  1.2068,  1.0225,  ...,  1.2837,  1.3713,  0.9006],
        [ 0.7866,  1.3482,  1.1423,  ...,  1.4340,  1.5319,  1.0060]])
tensor([ 0.0065, -0.6185, -0.7351,  1.4288,  3.1651,  0.5760, -1.5725, -1.7545,
         0.3321,  3.0410,  1.4884, -1.7543,  1.7647,  1.5608,  3.0300, -1.2516,
         1.4375, -3.0647,  2.1893,  1.0579, -2.0025, -3.3547,  1.8892, -0.3292,
        -1.3015,  0.2360,  0.8959,  0.5413,  0.1860,  1.6856, -1.0035,  3.7435,
        -2.3033, -0.7822,  1.3784,  1.3075, -2.5018, -1.5018,  2.3359,  1.1987,
         0.9791,  1.4405,  0.5719, -0.7628,  1.5323, -2.3249, -0.4792, -2.3890,
        -2.9083, -2.4543, -0.8187, -0.5089, -1.1570, -3.4229, -1.1228, -1.5827,
        -2.3074,  2.3675, -3.7123, -3.3584, -0.0685,  0.3362,  2.4232,  1.3508,
        -2.6919,  0.0660,  1.4497, -3.5126,  1.8065, -0.1005, -0.7967,  0.1015,
        -4.0614, -1.3206,  0.6517,  0.7344,  0.8240, -1.5529,  0.6521, -0.9247,
         3.3348, -2.6734, -0.8511,  0.4197, -0.0914,  1.1393,  0.6017,  0.4064,
        -1.8118,  0.4393, -5.3718,  1.0121, -2.0927, -0.2776, -2.6336, -3.9693,
        -0.8038, -1.8769,  2.0786,  2.3221])
tensor([[[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]],

        [[-0.2806],
         [ 0.6753],
         [ 0.0123]]])
tensor([-4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512,
        -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512,
        -4.1512, -4.1512, -4.1512, -4.1512])
tensor([[-17.4855,   7.2912, -10.7309,   3.8056,  -7.4810,   1.3819,  -3.4382,
          -3.0156,   1.1380,  15.0192,   7.2064,   0.7653,  -5.7029,   1.2748,
          -1.6344,  -0.8372,  -2.9532,   4.0480,  -2.0697,   0.3331,  -7.6875,
           3.3576,   4.7645,   1.1968,  -1.2719,  -2.7470,   7.1204, -13.5199,
          -4.0260,   1.8613,  -2.0579,  11.1797,  -3.6872,   3.0268,  -1.5994,
          -0.0275,   4.1218,   6.7948,  -4.1513,   2.0530,   0.6553,  -3.7766,
         -12.5851,  -5.7439, -10.5449,  -7.6828,  -4.0595,  -0.6027,  -7.1036,
          -3.3682, -11.7223,   1.3113,   0.8308,   2.0291,   2.5506,  -2.5968,
          -2.7273,  -2.3901,  -5.1041,   4.1714,   3.6672,  -2.9201,   4.1873,
          -0.9386,  16.7366,   3.4055,  -3.1839,  -0.1989,  -2.5431,  -0.8971,
          -4.5289,  14.4361,  -1.6026,   1.2648,   7.6281,  -7.9868,  -1.0889,
           0.5460,   9.4176,   0.7032, -10.7664,  -5.9075,   1.7435,  -6.4666,
          -5.5959,   8.3148,  -1.2747,  -1.9817,  -6.3998,  -8.4261,  -8.2749,
          -5.5242,   6.4204,  -7.2764,   1.4699,   3.1666,  -4.0820,   4.4864,
           6.4716,  10.8337],
        [-17.4855,   7.2912, -10.7309,   3.8056,  -7.4810,   1.3819,  -3.4382,
          -3.0156,   1.1380,  15.0192,   7.2064,   0.7653,  -5.7029,   1.2748,
          -1.6344,  -0.8372,  -2.9532,   4.0480,  -2.0697,   0.3331,  -7.6875,
           3.3576,   4.7645,   1.1968,  -1.2719,  -2.7470,   7.1204, -13.5199,
          -4.0260,   1.8613,  -2.0579,  11.1797,  -3.6872,   3.0268,  -1.5994,
          -0.0275,   4.1218,   6.7948,  -4.1513,   2.0530,   0.6553,  -3.7766,
         -12.5851,  -5.7439, -10.5449,  -7.6828,  -4.0595,  -0.6027,  -7.1036,
          -3.3682, -11.7223,   1.3113,   0.8308,   2.0291,   2.5506,  -2.5968,
          -2.7273,  -2.3901,  -5.1041,   4.1714,   3.6672,  -2.9201,   4.1873,
          -0.9386,  16.7366,   3.4055,  -3.1839,  -0.1989,  -2.5431,  -0.8971,
          -4.5289,  14.4361,  -1.6026,   1.2648,   7.6281,  -7.9868,  -1.0889,
           0.5460,   9.4176,   0.7032, -10.7664,  -5.9075,   1.7435,  -6.4666,
          -5.5959,   8.3148,  -1.2747,  -1.9817,  -6.3998,  -8.4261,  -8.2749,
          -5.5242,   6.4204,  -7.2764,   1.4699,   3.1666,  -4.0820,   4.4864,
           6.4716,  10.8337],
        [-17.4855,   7.2912, -10.7309,   3.8056,  -7.4810,   1.3819,  -3.4382,
          -3.0156,   1.1380,  15.0192,   7.2064,   0.7653,  -5.7029,   1.2748,
          -1.6344,  -0.8372,  -2.9532,   4.0480,  -2.0697,   0.3331,  -7.6875,
           3.3576,   4.7645,   1.1968,  -1.2719,  -2.7470,   7.1204, -13.5199,
          -4.0260,   1.8613,  -2.0579,  11.1797,  -3.6872,   3.0268,  -1.5994,
          -0.0275,   4.1218,   6.7948,  -4.1513,   2.0530,   0.6553,  -3.7766,
         -12.5851,  -5.7439, -10.5449,  -7.6828,  -4.0595,  -0.6027,  -7.1036,
          -3.3682, -11.7223,   1.3113,   0.8308,   2.0291,   2.5506,  -2.5968,
          -2.7273,  -2.3901,  -5.1041,   4.1714,   3.6672,  -2.9201,   4.1873,
          -0.9386,  16.7366,   3.4055,  -3.1839,  -0.1989,  -2.5431,  -0.8971,
          -4.5289,  14.4361,  -1.6026,   1.2648,   7.6281,  -7.9868,  -1.0889,
           0.5460,   9.4176,   0.7032, -10.7664,  -5.9075,   1.7435,  -6.4666,
          -5.5959,   8.3148,  -1.2747,  -1.9817,  -6.3998,  -8.4261,  -8.2749,
          -5.5242,   6.4204,  -7.2764,   1.4699,   3.1666,  -4.0820,   4.4864,
           6.4716,  10.8337]])
tensor([60., 60., 60.])

Case 2: tensors produced by operations whose inputs all have requires_grad=False:

f = a + b

print(f.is_leaf, f.grad_fn)

out:True None

Here f was not created by the user but produced by an operation on other tensors, yet it is still treated as a leaf node, and its grad_fn is None.

To summarise: a tensor is a leaf node if its requires_grad is False, or if it was created by the user (even with requires_grad=True). When gradients are computed, every node with requires_grad=True gets a gradient, but the gradients of the non-leaf nodes are then freed to save memory, since in general only the gradients of leaf nodes are of any use (they are what the optimizer works with).

import torch

a = torch.rand(5, 5, requires_grad=False)
b = torch.rand(5, 5, requires_grad=False)
c = torch.rand(5, 5, requires_grad=True)

f = torch.add(a, b)
g = b + c

f_dt = f.detach()  # detached view of f: requires_grad=False, grad_fn=None

h = f_dt + a  # f_dt and a both have requires_grad=False, so h has requires_grad=False and is a leaf node

y = 2 * h * g

# y is not a scalar, so backward needs an explicit gradient argument
y.backward(torch.ones_like(y))

# h never had a gradient; g's gradient was computed and then freed
print(h.grad, h.is_leaf, h.requires_grad, h.grad_fn)
print(g.grad, g.is_leaf, g.requires_grad, g.grad_fn)
print(y.grad_fn, y.grad_fn.next_functions)
print(c.grad, b.grad)

Output:

None True False None
None False True <AddBackward0 object at 0x0000018A5B437D68>
<MulBackward0 object at 0x0000018A5B437D68> ((None, 0), (<AddBackward0 object at 0x0000018A577320B8>, 0))
tensor([[4.5234, 4.5274, 2.3929, 4.2744, 1.5618],
        [2.8044, 3.5590, 1.5690, 2.8343, 4.1017],
        [2.5466, 3.9801, 4.4518, 3.5405, 3.0842],
        [4.2424, 3.1682, 3.4410, 3.4308, 4.1553],
        [3.3562, 0.4991, 3.6946, 0.6475, 2.7033]]) None

For a non-leaf node, if we really do want its gradient, we can get it with the retain_grad method or with a hook.
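
retain_grad (a minimal sketch; the toy graph here mirrors the hook example below — calling retain_grad() on a non-leaf tensor before backward tells autograd to keep its .grad instead of freeing it):

import torch

x = torch.rand(1, requires_grad=True)
y = 3 * x          # non-leaf node
z = y ** 2

y.retain_grad()    # keep y's gradient after backward instead of freeing it

z.backward()

print(y.grad)      # dz/dy = 2 * y = 6 * x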

hook:

import torch


grads = {}


def save_grad(name):
    # returns a hook that stores whatever gradient it receives under the given name
    def hook(grad):
        grads[name] = grad
    return hook


x = torch.rand(1, requires_grad=True)
y = 3 * x          # non-leaf node
z = y ** 2

# the hook is called with y's gradient during backward
y.register_hook(save_grad('y'))

z.backward()

print(grads['y'])  # dz/dy = 2 * y = 6 * x

In ordinary network training, backpropagation is aimed at the weights and biases w and b; we never backpropagate into the input tensors themselves. After a detach, the detached tensor no longer needs a gradient, but we did not need its gradient in the first place; what we actually need are the gradients of the convolution-layer parameters.

A tensor converted from image data is a leaf node with requires_grad=False, but after it passes through a convolution layer whose parameters have requires_grad=True, the resulting feature map is a non-leaf node with requires_grad=True (and is_leaf == False).
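
A minimal sketch of this (the shapes and layer here are made up for illustration):

import torch
import torch.nn as nn

img = torch.rand(1, 3, 32, 32)         # "image" tensor: a leaf with requires_grad=False
conv = nn.Conv2d(3, 16, 3, padding=1)  # its weight and bias are leaves with requires_grad=True

feat = conv(img)                       # feature map produced by a differentiable op

print(img.is_leaf, img.requires_grad)    # True False
print(feat.is_leaf, feat.requires_grad)  # False True
print(feat.grad_fn)                      # e.g. <ConvolutionBackward0 ...> (name varies by version)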

When backward() is called on the final loss, its grad_fn is invoked to carry out the differentiation. This walks through grad_fn.next_functions, takes out the Function objects stored there (including the AccumulateGrad nodes) and runs them; the process is recursive and stops when it reaches leaf nodes (the input images, or nodes split off by detach). Each AccumulateGrad node writes the computed result into the .grad attribute of the leaf tensor it references (these references were stored when the graph hanging off loss.grad_fn was built). By the time backward finishes, the .grad of every leaf node with requires_grad=True has been updated, while the gradients of the non-leaf nodes have been computed and then released.
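
The backward graph can be inspected directly (a minimal sketch; w plays the role of a parameter and x the role of an input):

import torch

w = torch.rand(3, requires_grad=True)  # leaf, like a network parameter
x = torch.rand(3)                      # leaf with requires_grad=False, like an input

loss = torch.sum(w * x)

# walking next_functions from loss.grad_fn eventually reaches an AccumulateGrad
# node, which is what writes into w.grad
print(loss.grad_fn)                  # <SumBackward0 ...>
print(loss.grad_fn.next_functions)   # ((<MulBackward0 ...>, 0),)
mul_fn = loss.grad_fn.next_functions[0][0]
print(mul_fn.next_functions)         # ((<AccumulateGrad ...>, 0), (None, 0))

loss.backward()
print(w.grad)                        # equals x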

In fact both the feature maps and the variables inside the convolution layers have requires_grad == True, so backward computes gradients flowing through all of them. But the optimizer is constructed with the parameters to be optimized, model.parameters(), so only the values of the model parameters (convolution weights and so on) are actually updated.
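
A minimal sketch of that division of labour, reusing the Net defined earlier (the learning rate is arbitrary):

import torch

net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  # only these tensors will be updated

inp = torch.rand(3, 3, 20)   # requires_grad=False, so no .grad accumulates on the input
loss = torch.sum(net(inp))

optimizer.zero_grad()
loss.backward()   # gradients flow through the feature maps, but .grad is only stored on the leaf parameters
optimizer.step()  # updates net's parameters in place; inp and the feature maps are untouched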

About the detach function: it produces a new view of the original tensor, sharing its data, but with requires_grad set to False and grad_fn set to None. When the later loss backpropagates to this point it finds a leaf node and stops, so none of the nodes further back receive any gradient through this branch, and they will not be optimized by this branch's loss. This is what is meant by detach "blocking" propagation.
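
A minimal sketch of this blocking behaviour (the tensors are made up for illustration):

import torch

w = torch.rand(3, requires_grad=True)

y = w * 2
y_det = y.detach()         # shares data with y, but requires_grad=False and grad_fn=None

loss = (y_det * 3).sum()   # the graph for this loss stops at y_det
# loss.backward()          # would raise an error: loss does not require grad at all

loss2 = (y * 3).sum() + (y_det * 3).sum()
loss2.backward()
print(w.grad)              # tensor([6., 6., 6.]): only the y branch contributed, nothing flowed through y_det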
