pytorch張量複製clone()和detach()

轉自:https://blog.csdn.net/winycg/article/details/100813519

tensor複製可以使用clone()函數和detach()函數即可實現各種需求。

clone

clone()函數可以返回一個完全相同的tensor,新的tensor開闢新的內存,但是仍然留在計算圖中。

clone操作在不共享數據內存的同時支持梯度回溯,所以常用在神經網絡中某個單元需要重複使用的場景下。

detach

detach()函數可以返回一個完全相同的tensor,新的tensor開闢與舊的tensor共享內存,新的tensor會脫離計算圖,不會牽扯梯度計算。此外,一些原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在兩者任意一個執行都會引發錯誤。
detach操作在共享數據內存的脫離計算圖,所以常用在神經網絡中僅要利用張量數值,而不需要追蹤導數的場景下。

使用分析

# Operation New/Shared memory Still in computation graph
tensor.clone() New Yes
tensor.detach() Shared No
tensor.clone().detach() New No

如下執行一些實例:
首先導入包並固定隨機種子

import torch
torch.manual_seed(0)

1.clone()之後的tensor requires_grad=True,detach()之後的tensor requires_grad=False,但是梯度並不會流向clone()之後的tensor

x= torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.clone().detach()

f = torch.nn.Linear(3, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)

輸出:

tensor([-0.0043,  0.3097, -0.4752])
True
None
False
False

2.將計算圖中參與運算tensor變爲clone()後的tensor。此時梯度仍然只流向了原始的tensor。

x= torch.tensor([1., 2., 3.], requires_grad=True)
clone_x = x.clone()
detach_x = x.detach()
clone_detach_x = x.detach().clone()

f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()

print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)

輸出:

tensor([-0.0043,  0.3097, -0.4752])
None
False
False

3.將原始tensor設爲requires_grad=False,clone()後的梯度設爲.requires_grad_(),clone()後的tensor參與計算圖的運算,則梯度傳向clone()前的tensor。

x= torch.tensor([1., 2., 3.], requires_grad=False)
clone_x = x.clone().requires_grad_()
detach_x = x.detach()
clone_detach_x = x.detach().clone()

f = torch.nn.Linear(3, 1)
y = f(clone_x)
y.backward()

print(x.grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)

輸出:

tensor([-0.0043,  0.3097, -0.4752])
None
False
False

4.detach()後的tensor由於與原始tensor共享內存,所以原始tensor在計算圖中數值反向傳播更新之後,detach()的tensor值也發生了改變。

x = torch.tensor([1., 2., 3.], requires_grad=True)
f = torch.nn.Linear(3, 1)
w = f.weight.detach()
print(f.weight)
print(w)

y = f(x)
y.backward()

optimizer = torch.optim.SGD(f.parameters(), 0.1)
optimizer.step()

print(f.weight)
print(w)

輸出:

Parameter containing:
tensor([[-0.0043,  0.3097, -0.4752]], requires_grad=True)
tensor([[-0.0043,  0.3097, -0.4752]])
Parameter containing:
tensor([[-0.1043,  0.1097, -0.7752]], requires_grad=True)
tensor([[-0.1043,  0.1097, -0.7752]])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章