本文爲對YouTube博主 Elliot Waite所講視頻的記錄與思考,視頻地址 PyTorch Autograd Explained - In-depth Tutorial
autograd 流程圖
例 1.
a = torch.tensor(2.0)
b = torch.tensor(3.0)
c = a*b
前項的計算圖如下:
每個方框代表一個tensor,其中列出一些屬性(還有其他很多屬性):
- .data 存了tensor的data
- .grad 當計算gradient的時候將會存入此函數對應情況下的gradient
- .grad_fn 指向用於backward的函數的節點
- .is_leaf 判斷是否是葉節點 (關於葉節點的信息請移步: one way的pytorch學習筆記(三)leaf 葉子(張量))
- .requires_grad 如果是設爲
True
,那麼在做backward時候將作爲圖的一部分參與backwards運算,如果爲False
則不參加backwards運算
在圖中可見, c= a*b的運算也算作計算圖的一部分, 用Mul表示.由於a 和b 是require_grad,所以自動的被算爲 is_leaf 爲True, 至於爲什麼請見one way的pytorch學習筆記(三)leaf 葉子(張量) ,此時由於requires_grad都爲False,因此沒有backwards的graph.
例 2.
a = torch.tensor(2.0,requires_grad = True)
b = torch.tensor(3.0)
c = a*b
c.backward()
重新進行計算,設tensor a的 requires_grad 爲True
,輸出結果c因爲輸入自變量的屬性爲True而自動改變成 requires_grad 爲True
.這說明只要自變量中有一個requires_grad 爲True
, 進一步通過運算生成的變量也爲True. 此時的c 爲非葉節點, grad_fn 指向做backwards時與當前變量相關的backwards的函數(函數爲pytorch自動生成的).
- 當我們調用tensor的乘法函數時,同時調用了隱性變量 ctx (context)變量的save_for_backward(),這樣就把此函數做backward時所需要的從forward函數中獲取的相關的一些值存到了ctx中.ctx起到了緩存相關參數的作用,變成連接forward與backward之間的緩存站. ctx中的值將會在c 做backwards時傳遞給對應的Mulbackward 操作.
- 與此同時 由於c是通過 c=a*b運算得來的, c的grad_fn中存了做backwards時候對應的函數.且把這個對應的backward 叫做 “MulBackward”
- 當進行c的backwards的時候,其實也就相當於執行了 c = a*b這個函數分別對 a 與b 做的偏導. 那麼理應對應兩組backwards的函數,這兩組backwards的函數打包存在 MulBackward的 next_functions 中. nex_function爲一個 tuple list, AccumulateGrad 將會把相應得到的結果送到 a.grad中.
- 於是在進行 c.backward() 後, c進行關於a以及關於b的求導,由於b設requires_grad爲
False
,因此b項不參與backwards運算(自然的,next_function中list的第二個tuple即爲None
),c關於a的梯度爲3,因此3將傳遞給AccumulaGrad進一步傳給a.grad 因此 a.grad 的結果爲3
從這個backward graph中,可以看出,其實pytorch在定義這些變量的運算函數時,其實也定義了函數對應的backwards的函數.如果想使用自定義的函數,那麼自己也必須要定義backwards函數.
例 3.
a = torch.tensor(2.0,requires_grad = True)
b = torch.tensor(3.0,requires_grad = True)
c = a*b
d = torch.tensor(4.0,requires_grad = True)
e = c*d
e.backward()
- e的grad_fn 指向節點 MulBackward, c的grad_fn指向另一個節點 MulBackward
- c 爲中間值is_leaf 爲
False
,因此並不包含 grad值,在backward計算中,並不需要再重新獲取c.grad的值, backward的運算直接走相應的backward node 即可 - MulBackward 從 ctx.saved_tensor中調用有用信息, e= c+d中 e關於c的梯度通過MulBackward 獲取得4. 根據鏈式規則, 4再和上一階段的 c關於 a和c關於b的兩個梯度值3和2相乘,最終得到了相應的值12 和8
- a.grad 中存入12, b.grad中存入 8