autograd包是pytorch中所有升級網絡的核心。autograd軟件包爲tensor上的所有操作提供自動微分。它是一個由運行定義的框架,這意味着以代碼運行方式定義你的後向傳播,並且每次迭代都可以不同。
TENSOR
跟蹤
torch.Tensor是包的核心類。如果將其屬性.requires_grad
設置爲True,則會開始跟蹤針對tensor的所有操作。完成計算後,可以調用.backward()
來自動計算所有梯度。該張量的梯度將累積到.grad屬性中。
停止跟蹤
要停止tensor歷史記錄的跟蹤,可以調用.detach()
,它將其與技術歷史記錄分離,並防止將來的計算被跟蹤。
要停止跟蹤歷史記錄和使用內存,還可以將代碼塊使用with torch.no_grad():
包裝起來。在評估模型時,特別有用,因爲模型在訓練階段具有requires_grad = True
的可訓練參數有利於調參,但在評估階段我們不需要梯度。
FUNCTION
還有一個類對於autograd實現非常重要的,是Function。Tensor和Function互相連接並構建一個非循環圖,它保存真個完整的計算過程的歷史信息。每個張量都有一個.grad_fn
屬性保存着創建了張量的Function的引用。如果用戶自己創建張量,則grad_fn
是None。
如果計算導數,可以調用Tensor.backward()
。如果tensor是標量(包含一個元素數據),則不需要制定任何參數backward(),但是如果它有更多的元素,則需要指定一個gradient參數來指定張量的形狀。
例子
import torch
創建一個張量,設置requires_grad=True來跟蹤與它相關的計算
x = torch.ones(2, 2, requires_grad=True)
print(x)
輸出
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
針對張量做一個操作
y = x + 2
print(y)
輸出
tensor([[3., 3.],
[3., 3.]], grad_fn=<AddBackward0>)
y作爲操作的結果被創建,所以它有grad_fn
print(y.grad_fn)
輸出
<AddBackward0 object at 0x7fe1db427470>
針對y做更多的操作
z = y * y * 3
out = z.mean()
print(z, out)
輸出
tensor([[27., 27.],
[27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)
.requires_grad_(…)會改變張量的requires_grad標記。輸入的標記默認爲False,如果沒有提供相應的參數。
a = torch.randn(2, 2)
a = ((a * 3)/(a -1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a*a).sum()
print(b.grad_fn)
輸出
False
True
<SumBackward0 object at 0x7fe1db427dd8>
梯度
現在後向傳播,因爲輸出包含了一個標量,out.backward()
等同於out.backward(torch.tensor(1.))
。
out.backward()
打印梯度d(out)/dx
print(x.grad)
輸出
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
雅可比例子
x = torch.randn(3, requires_grad=True)
y = x*2
while y.data.norm()<1000:
y = y*2
print(y)
輸出
tensor([ -444.6791, 762.9810, -1690.0941], grad_fn=<MulBackward0>)
現在在這種情況下,y不再是一個標量。torch.autograd不能夠直接計算整個雅可比,但是如果我們只要雅可比向量積,只需要簡單的傳遞向量給backward作爲參數
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)
print(x.grad)
輸出
tensor([1.0240e+02, 1.0240e+03, 1.0240e-01])
停止跟蹤例子
可以通過將代碼包裹在with torch.no_grad()
,來停止對從跟蹤歷史中的.requires_grad=True
的張量自動求導。
print(x.requires_grad)
print((x**2).requires_grad)
with torch.no_grad():
print((x**2).requires_grad)
輸出
True
True
False
更詳細的可以訪問: https://pytorch.org/docs/autograd