當我們在訓練網絡的時候可能希望保持一部分的網絡參數不變,只對其中一部分的參數進行調整;或者只訓練部分分支網絡,並不讓其梯度對主網絡的梯度造成影響,這時候我們就需要使用detach()函數來切斷一些分支的反向傳播。
一、detach()[source]
返回一個新的Variable
,從當前計算圖中分離下來的,但是仍指向原變量的存放位置,不同之處只是requires_grad爲false,得到的這個Variable
永遠不需要計算其梯度,不具有grad。
即使之後重新將它的requires_grad置爲true,它也不會具有梯度grad
這樣我們就會繼續使用這個新的Variable進行計算,後面當我們進行反向傳播時,到該調用detach()的Variable
就會停止,不能再繼續向前進行傳播。
源碼爲:
def detach(self): """Returns a new Variable, detached from the current graph. Result will never require gradient. If the input is volatile, the output will be volatile too. .. note:: Returned Variable uses the same data tensor, as the original one, and in-place modifications on either of them will be seen, and may trigger errors in correctness checks. """ result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result
可見函數進行的操作有:
- 將grad_fn設置爲None
將Variable
的requires_grad設置爲False
如果輸入 volatile=True(即不需要保存記錄,當只需要結果而不需要更新參數時這麼設置來加快運算速度)
,那麼返回的Variable
。(volatile
=True
已經棄用)volatile
注意:
返回的
Variable
和原始的Variable
公用同一個data tensor
。in-place函數
修改會在兩個Variable
上同時體現(因爲它們共享data tensor
),當要對其調用backward()時可能會導致錯誤。
舉例:
一個正常的例子:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
out.sum().backward()
print(a.grad)
返回:
None
tensor([0.1966, 0.1050, 0.0452])
當使用detach()但是沒有進行更改時,並不會影響backward():
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
#添加detach(),c的requires_grad爲False
c = out.detach()
print(c)
#這時候沒有對c進行更改,所以並不會影響backward()
out.sum().backward()
print(a.grad)
返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])
可見c,out之間的區別是c是沒有梯度的,out是有梯度的
如果這裏使用的是c進行sum()操作並進行backward(),則會報錯:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
#添加detach(),c的requires_grad爲False
c = out.detach()
print(c)
#使用新生成的Variable進行反向傳播
c.sum().backward()
print(a.grad)
返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
Traceback (most recent call last):
File "test.py", line 13, in <module>
c.sum().backward()
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
如果此時對c進行了更改,這個更改會被autograd追蹤,在對out.sum()進行backward()時也會報錯,因爲此時的值進行backward()得到的梯度是錯誤的:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
#添加detach(),c的requires_grad爲False
c = out.detach()
print(c)
c.zero_() #使用in place函數對其進行修改
#會發現c的修改同時會影響out的值
print(c)
print(out)
#這時候對c進行更改,所以會影響backward(),這時候就不能進行backward(),會報錯
out.sum().backward()
print(a.grad)
返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
Traceback (most recent call last):
File "test.py", line 16, in <module>
out.sum().backward()
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/anaconda3/envs/deeplearning/lib/python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
二、data
如果上面的操作使用的是.data,效果會不同:
這裏的不同在於.data的修改不會被autograd追蹤,這樣當進行backward()時它不會報錯,回得到一個錯誤的backward值
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
c = out.data
print(c)
c.zero_() #使用in place函數對其進行修改
#會發現c的修改同時也會影響out的值
print(c)
print(out)
#這裏的不同在於.data的修改不會被autograd追蹤,這樣當進行backward()時它不會報錯,回得到一個錯誤的backward值
out.sum().backward()
print(a.grad)
返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
tensor([0., 0., 0.])
上面的內容實現的原理是:In-place 正確性檢查
所有的
Variable
都會記錄用在他們身上的in-place operations
。如果pytorch
檢測到variable
在一個Function
中已經被保存用來backward
,但是之後它又被in-place operations
修改。當這種情況發生時,在backward
的時候,pytorch
就會報錯。這種機制保證了,如果你用了in-place operations
,但是在backward
過程中沒有報錯,那麼梯度的計算就是正確的。
下面結果正確是因爲改變的是sum()的結果,中間值a.sigmoid()並沒有被影響,所以其對求梯度並沒有影響:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid().sum() #但是如果sum寫在這裏,而不是寫在backward()前,得到的結果是正確的
print(out)
c = out.data
print(c)
c.zero_() #使用in place函數對其進行修改
#會發現c的修改同時也會影響out的值
print(c)
print(out)
#沒有寫在這裏
out.backward()
print(a.grad)
返回:
None
tensor(2.5644, grad_fn=<SumBackward0>)
tensor(2.5644)
tensor(0.)
tensor(0., grad_fn=<SumBackward0>)
tensor([0.1966, 0.1050, 0.0452])
三、 detach_()[source]
將一個Variable
從創建它的圖中分離,並把它設置成葉子variable
其實就相當於變量之間的關係本來是x -> m -> y,這裏的葉子variable是x,但是這個時候對m進行了.detach_()操作,其實就是進行了兩個操作:
- 將m的grad_fn的值設置爲None,這樣m就不會再與前一個節點x關聯,這裏的關係就會變成x, m -> y,此時的m就變成了葉子結點
- 然後會將m的requires_grad設置爲False,這樣對y進行backward()時就不會求m的梯度
這麼一看其實detach()和detach_()很像,兩個的區別就是detach_()是對本身的更改,detach()則是生成了一個新的variable
比如x -> m -> y中如果對m進行detach(),後面如果反悔想還是對原來的計算圖進行操作還是可以的
但是如果是進行了detach_(),那麼原來的計算圖也發生了變化,就不能反悔了。
轉自:https://blog.csdn.net/weixin_34363171/article/details/94236818