PyTorch自動微分

autograd包是pytorch中所有升級網絡的核心。autograd軟件包爲tensor上的所有操作提供自動微分。它是一個由運行定義的框架,這意味着以代碼運行方式定義你的後向傳播,並且每次迭代都可以不同。

TENSOR

跟蹤

torch.Tensor是包的核心類。如果將其屬性.requires_grad設置爲True,則會開始跟蹤針對tensor的所有操作。完成計算後,可以調用.backward()來自動計算所有梯度。該張量的梯度將累積到.grad屬性中。

停止跟蹤

停止tensor歷史記錄的跟蹤,可以調用.detach(),它將其與技術歷史記錄分離,並防止將來的計算被跟蹤。

要停止跟蹤歷史記錄和使用內存,還可以將代碼塊使用with torch.no_grad():包裝起來。在評估模型時,特別有用,因爲模型在訓練階段具有requires_grad = True的可訓練參數有利於調參,但在評估階段我們不需要梯度。

FUNCTION

還有一個類對於autograd實現非常重要的,是Function。Tensor和Function互相連接並構建一個非循環圖,它保存真個完整的計算過程的歷史信息。每個張量都有一個.grad_fn屬性保存着創建了張量的Function的引用。如果用戶自己創建張量,則grad_fn是None。

如果計算導數,可以調用Tensor.backward()。如果tensor是標量(包含一個元素數據),則不需要制定任何參數backward(),但是如果它有更多的元素,則需要指定一個gradient參數來指定張量的形狀。

例子

import torch

創建一個張量,設置requires_grad=True來跟蹤與它相關的計算

x = torch.ones(2, 2, requires_grad=True)
print(x)

輸出

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

針對張量做一個操作

y = x + 2
print(y)

輸出

tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)

y作爲操作的結果被創建,所以它有grad_fn

print(y.grad_fn)

輸出

<AddBackward0 object at 0x7fe1db427470>

針對y做更多的操作

z = y * y * 3
out = z.mean()
print(z, out)

輸出

tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)

.requires_grad_(…)會改變張量的requires_grad標記。輸入的標記默認爲False,如果沒有提供相應的參數。

a = torch.randn(2, 2)
a = ((a * 3)/(a -1))
print(a.requires_grad)
a.requires_grad_(True)
print(a.requires_grad)
b = (a*a).sum()
print(b.grad_fn)

輸出

False
True
<SumBackward0 object at 0x7fe1db427dd8>

梯度

現在後向傳播,因爲輸出包含了一個標量,out.backward()等同於out.backward(torch.tensor(1.))

out.backward()

打印梯度d(out)/dx

print(x.grad)

輸出

tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

雅可比例子

x = torch.randn(3, requires_grad=True)

y = x*2
while y.data.norm()<1000:
    y = y*2

print(y)

輸出

tensor([ -444.6791,   762.9810, -1690.0941], grad_fn=<MulBackward0>)

現在在這種情況下,y不再是一個標量。torch.autograd不能夠直接計算整個雅可比,但是如果我們只要雅可比向量積,只需要簡單的傳遞向量給backward作爲參數

v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)

print(x.grad)

輸出

tensor([1.0240e+02, 1.0240e+03, 1.0240e-01])

停止跟蹤例子

可以通過將代碼包裹在with torch.no_grad(),來停止對從跟蹤歷史中的.requires_grad=True的張量自動求導。

print(x.requires_grad)
print((x**2).requires_grad)

with torch.no_grad():
    print((x**2).requires_grad)

輸出

True
True
False

更詳細的可以訪問: https://pytorch.org/docs/autograd

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章