pytorch之in-place operation #含義 #代碼示例 #兩種情況不能使用inplace operation

一、in-place含義

in-place operation在pytorch中是指改變一個tensor的值的時候,不經過複製操作,而是直接在原來的內存上改變它的值。可以稱之爲“原地操作符”。

注意:PyTorch操作inplace版本都有後綴"_", 例如y.add_(x),x.copy_(y),x.t_()

python裏面的+=*=也是in-place operation

如果你使用了in-place operation而沒有報錯的話,那麼你可以確定你的梯度計算是正確的。

二、in-place代碼示例

import torch

x = torch.rand(5, 3)
y = torch.rand(5, 3)

# 加法形式一:+
print(x + y)

# 加法形式二:add
print(torch.add(x, y))
# add還可指定輸出
result = torch.empty(5, 3)
torch.add(x, y, out=result)
print(result)

# 加法形式三:inplace
y.add_(x) # adds x to y
print(y)

三、在pytorch中, 有兩種情況不能使用inplace operation

1、對於requires_grad=True的葉子張量(leaf tensor) 不能使用 inplace operation

2、對於在求梯度階段需要用到的張量不能使用 inplace operation

第一種情況: requires_grad=True 的 leaf tensor

import torch

w = torch.FloatTensor(10) # w 是個 leaf tensor
w.requires_grad = True    # 將 requires_grad 設置爲 True
w.normal_()               # 執行這句話就會報錯

在這裏插入圖片描述
報錯信息爲:
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
——在inplace operation中使用了需要grad的葉子變量

對比:requires_grad=False 的 leaf tensor

import torch

w = torch.FloatTensor(10) # w 是個 leaf tensor
# 默認requires_grad=False
print(w)
print(w.normal_())   

在這裏插入圖片描述
附1: pytorch函數之torch.normal()

  • Returns a Tensor of random numbers drawn from separate normal distributions who’s mean and standard deviation are given.
  • 官網給出的解釋,意思返回一個張量,張量裏面的隨機數是從相互獨立的正態分佈中隨機生成的。

第二種情況: 求梯度階段需要用到的張量

import torch
x = torch.FloatTensor([[1., 2.]])    # print(x.shape)--> torch.Size([1, 2])
w1 = torch.FloatTensor([[2.], [1.]]) # torch.Size([2, 1])
w2 = torch.FloatTensor([3.]) # torch.Size([1])

w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1)  # tensor([[4.]], grad_fn=<MmBackward>) torch.Size([1, 1])
f = torch.matmul(d, w2)  # tensor([12.], grad_fn=<MvBackward>) torch.Size([1])
d[:] = 1 # 因爲這句,代碼會報錯
f.backward()

在這裏插入圖片描述
報錯信息爲:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
——梯度計算所需的變量之一已通過in-place operation進行了修改
原因:
f=matmul(d,w2)fw2=g(d)f=\operatorname{matmul}(d, w 2),\frac{\partial f}{\partial w 2}=g(d)

  • 在計算ff的時候,dd是等於某個值的,ff對於w2w 2的導數是和這時候的dd值相關的;
  • 但是計算完ff之後,dd的值變了,這就會導致f.backward()對於w2w 2的導數計算出錯因而報錯;
  • 造成這個問題的主要原因是:在執行 f = torch.matmul(d, w2)時,pytorch的反向求導機制保存了dd,爲了之後的反向求導計算。

這樣修改就沒有問題了:
在改變dd之後再對ff進行運算操作
在這裏插入圖片描述
附1: f.backward(),默認只對w2w 2求梯度原因:
f.backward(parameters)接受的參數parameters必須要和f的大小一模一樣,然後作爲f的係數傳回去。
如果設定只傳入dd,報錯如下:
在這裏插入圖片描述
附2: pytorch函數之torch.matmul()
矩陣相乘有torch.mm和torch.matmul兩個函數。其中前一個是針對二維矩陣,後一個是高維。當torch.mm用於大於二維時將報錯。


參考鏈接:https://zhuanlan.zhihu.com/p/38475183

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章