First, let's define leaf nodes and non-leaf nodes:
A leaf node is a tensor whose is_leaf attribute is True and whose grad_fn is None. Leaf nodes arise in two situations:
First kind: tensors created directly by the user (i.e., not produced by an operation):
import torch

a = torch.rand(5, 5, requires_grad=False)
b = torch.rand(5, 5, requires_grad=False)
c = torch.rand(5, 5, requires_grad=True)
print(a.is_leaf, b.is_leaf, c.is_leaf)
out: True True True
Here a, b, and c are all leaf nodes. As you can see, any tensor created directly by the user is treated as a leaf node, regardless of whether its requires_grad is True. The parameters created inside an nn.Module's layers are user-created tensors too, so they are also leaf nodes:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(20, 100)
        self.conv1 = nn.Conv1d(3, 20, 1)
        self.linear2 = nn.Linear(100, 3)

    def forward(self, x):
        x = self.linear1(x)
        x = self.conv1(x)
        x = self.linear2(x)
        return x

net = Net()
# .named_parameters() lets us inspect each layer's parameters (gradients and other attributes)
for name, param in net.named_parameters():
    print(param.is_leaf)
loss = torch.sum(net(torch.rand(3, 3, 20)))
loss.backward()
print(loss)
for name, param in net.named_parameters():
    print(param.grad)
Output:
True
True
True
True
True
True
tensor(17.9995, grad_fn=<SumBackward0>)
tensor([[ 0.0022, 0.0038, 0.0032, ..., 0.0040, 0.0043, 0.0028],
[-0.2095, -0.3591, -0.3042, ..., -0.3819, -0.4080, -0.2680],
[-0.2490, -0.4268, -0.3616, ..., -0.4540, -0.4849, -0.3185],
...,
[-0.6358, -1.0897, -0.9233, ..., -1.1591, -1.2382, -0.8132],
[ 0.7041, 1.2068, 1.0225, ..., 1.2837, 1.3713, 0.9006],
[ 0.7866, 1.3482, 1.1423, ..., 1.4340, 1.5319, 1.0060]])
tensor([ 0.0065, -0.6185, -0.7351, 1.4288, 3.1651, 0.5760, -1.5725, -1.7545,
0.3321, 3.0410, 1.4884, -1.7543, 1.7647, 1.5608, 3.0300, -1.2516,
1.4375, -3.0647, 2.1893, 1.0579, -2.0025, -3.3547, 1.8892, -0.3292,
-1.3015, 0.2360, 0.8959, 0.5413, 0.1860, 1.6856, -1.0035, 3.7435,
-2.3033, -0.7822, 1.3784, 1.3075, -2.5018, -1.5018, 2.3359, 1.1987,
0.9791, 1.4405, 0.5719, -0.7628, 1.5323, -2.3249, -0.4792, -2.3890,
-2.9083, -2.4543, -0.8187, -0.5089, -1.1570, -3.4229, -1.1228, -1.5827,
-2.3074, 2.3675, -3.7123, -3.3584, -0.0685, 0.3362, 2.4232, 1.3508,
-2.6919, 0.0660, 1.4497, -3.5126, 1.8065, -0.1005, -0.7967, 0.1015,
-4.0614, -1.3206, 0.6517, 0.7344, 0.8240, -1.5529, 0.6521, -0.9247,
3.3348, -2.6734, -0.8511, 0.4197, -0.0914, 1.1393, 0.6017, 0.4064,
-1.8118, 0.4393, -5.3718, 1.0121, -2.0927, -0.2776, -2.6336, -3.9693,
-0.8038, -1.8769, 2.0786, 2.3221])
tensor([[[-0.2806],
         [ 0.6753],
         [ 0.0123]],
        ... (the same 3×1 block is repeated for all 20 output channels) ...
        [[-0.2806],
         [ 0.6753],
         [ 0.0123]]])
tensor([-4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512,
-4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512, -4.1512,
-4.1512, -4.1512, -4.1512, -4.1512])
tensor([[-17.4855, 7.2912, -10.7309, 3.8056, -7.4810, 1.3819, -3.4382,
-3.0156, 1.1380, 15.0192, 7.2064, 0.7653, -5.7029, 1.2748,
-1.6344, -0.8372, -2.9532, 4.0480, -2.0697, 0.3331, -7.6875,
3.3576, 4.7645, 1.1968, -1.2719, -2.7470, 7.1204, -13.5199,
-4.0260, 1.8613, -2.0579, 11.1797, -3.6872, 3.0268, -1.5994,
-0.0275, 4.1218, 6.7948, -4.1513, 2.0530, 0.6553, -3.7766,
-12.5851, -5.7439, -10.5449, -7.6828, -4.0595, -0.6027, -7.1036,
-3.3682, -11.7223, 1.3113, 0.8308, 2.0291, 2.5506, -2.5968,
-2.7273, -2.3901, -5.1041, 4.1714, 3.6672, -2.9201, 4.1873,
-0.9386, 16.7366, 3.4055, -3.1839, -0.1989, -2.5431, -0.8971,
-4.5289, 14.4361, -1.6026, 1.2648, 7.6281, -7.9868, -1.0889,
0.5460, 9.4176, 0.7032, -10.7664, -5.9075, 1.7435, -6.4666,
-5.5959, 8.3148, -1.2747, -1.9817, -6.3998, -8.4261, -8.2749,
-5.5242, 6.4204, -7.2764, 1.4699, 3.1666, -4.0820, 4.4864,
6.4716, 10.8337],
... (the second and third rows are identical to the first) ...])
tensor([60., 60., 60.])
Second kind: tensors produced by operations whose inputs all have requires_grad=False:
f = a + b
print(f.is_leaf, f.grad_fn)
out: True None
Here f was not created by the user but produced by an operation on other tensors, yet it is still treated as a leaf node, and its grad_fn is None.
To summarize: a tensor is a leaf node if its requires_grad is False, or if it was created by the user (even with requires_grad=True). During backward, the gradient of every node with requires_grad=True is computed, and the gradients of the non-leaf nodes among them are then freed to save memory, because in general only the leaf nodes' gradients are useful (they are what gets optimized). The example below illustrates this:
import torch

a = torch.rand(5, 5, requires_grad=False)
b = torch.rand(5, 5, requires_grad=False)
c = torch.rand(5, 5, requires_grad=True)
f = torch.add(a, b)
g = b + c
f_dt = f.detach()
h = f_dt + a  # f_dt and a both have requires_grad=False, so h has requires_grad=False and is a leaf node
y = 2 * h * g
y.backward(torch.ones_like(y))
# h never had a gradient; g's gradient was computed and then freed
print(h.grad, h.is_leaf, h.requires_grad, h.grad_fn)
print(g.grad, g.is_leaf, g.requires_grad, g.grad_fn)
print(y.grad_fn, y.grad_fn.next_functions)
print(c.grad, b.grad)
Output:
None True False None
None False True <AddBackward0 object at 0x0000018A5B437D68>
<MulBackward0 object at 0x0000018A5B437D68> ((None, 0), (<AddBackward0 object at 0x0000018A577320B8>, 0))
tensor([[4.5234, 4.5274, 2.3929, 4.2744, 1.5618],
[2.8044, 3.5590, 1.5690, 2.8343, 4.1017],
[2.5466, 3.9801, 4.4518, 3.5405, 3.0842],
[4.2424, 3.1682, 3.4410, 3.4308, 4.1553],
[3.3562, 0.4991, 3.6946, 0.6475, 2.7033]]) None
For a non-leaf node, if we really do need its gradient, we can use the retain_grad() method or register a hook.
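Using retain_grad() (a minimal sketch; retain_grad() simply tells autograd to keep this non-leaf tensor's .grad after backward instead of freeing it):
import torch

x = torch.rand(1, requires_grad=True)
y = 3 * x
y.retain_grad()  # ask autograd to keep y.grad even though y is a non-leaf node
z = y ** 2
z.backward()
print(y.grad)    # dz/dy = 2 * y; no longer None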
Using a hook:
import torch

# store gradients captured by hooks, keyed by name
grads = {}

def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

x = torch.rand(1, requires_grad=True)
y = 3 * x
z = y ** 2
y.register_hook(save_grad('y'))  # fires when y's gradient is computed during backward
z.backward()
print(grads['y'])
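Here dz/dy = 2y = 6x, so grads['y'] holds exactly the value that retain_grad() would have exposed as y.grad.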
The backprop we do in everyday training is about the gradients of w and b; we never backpropagate into the raw torch.tensor inputs. After detach(), a tensor simply no longer requires a gradient, but we did not need its gradient in the first place; what we need are the gradients of the convolution-layer parameters.
A tensor converted from image data is a leaf node with requires_grad=False, but after passing through a convolution layer whose parameters have requires_grad=True, the output feature maps become non-leaf nodes with requires_grad=True (is_leaf == False).
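A minimal sketch of this behaviour (the shapes and layer sizes here are arbitrary, chosen only for illustration):
import torch
import torch.nn as nn

img = torch.rand(1, 3, 8, 8)               # "image" tensor: a leaf with requires_grad=False
conv = nn.Conv2d(3, 4, kernel_size=3)      # its weight and bias have requires_grad=True
feat = conv(img)
print(img.is_leaf, img.requires_grad)      # True False
print(feat.is_leaf, feat.requires_grad)    # False True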
When loss.backward() is finally called, loss's grad_fn is invoked to perform the differentiation. This walks grad_fn.next_functions, takes out each Function stored there (such as AccumulateGrad), and executes its backward step. The process recurses until it reaches leaf nodes (the input images, or nodes split off by detach). Each result is written into the .grad attribute of the tensor referenced by that node's variable field (reached during the recursion over loss's grad_fn). By the time the recursion finishes, the .grad of every leaf node that requires a gradient has been updated.
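We can inspect this graph structure directly (a small sketch; next_functions is an internal, undocumented attribute, but it is the same one printed in the detach example above):
import torch

x = torch.rand(2, requires_grad=True)
y = (3 * x).sum()
print(y.grad_fn)                 # <SumBackward0 ...>
print(y.grad_fn.next_functions)  # ((<MulBackward0 ...>, 0),)
mul_fn = y.grad_fn.next_functions[0][0]
print(mul_fn.next_functions)     # one entry is (<AccumulateGrad ...>, 0), pointing at the leaf x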
In fact, the feature maps and the variables inside the convolutions all have requires_grad == True, so backward computes gradients for all of them. But the optimizer is constructed with the parameters to be optimized, model.parameters(), so only the values of the model parameters (convolution weights and the like) are actually updated.
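A minimal sketch (reusing the Net class from the example above; SGD and the learning rate are arbitrary choices for illustration):
import torch

net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # only these tensors will be updated
loss = torch.sum(net(torch.rand(3, 3, 20)))
optimizer.zero_grad()
loss.backward()   # gradients are computed along the whole graph
optimizer.step()  # but only the tensors passed to the optimizer are changed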
About the detach function: it produces a new view of the original tensor whose requires_grad is False. When backward later reaches this point, it finds a leaf node and stops propagating along this branch, so the grad values of all the nodes upstream are unaffected by this branch, and they are naturally not optimized by the loss through this path. This is the "blocking" of gradient propagation that detach is known for.
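A minimal sketch (the variable names are mine, for illustration): x reaches y both directly and through a detached branch, and only the direct branch contributes to x.grad:
import torch

x = torch.rand(3, requires_grad=True)
a = 2 * x
b = a.detach()     # b is split off: a leaf with requires_grad=False
y = (a + b).sum()
y.backward()
print(x.grad)      # all 2s: only the a branch contributes; without detach it would be all 4s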