【one way的pytorch學習筆記】(四)autograd的流程機制原理

本文爲對YouTube博主 Elliot Waite所講視頻的記錄與思考,視頻地址 PyTorch Autograd Explained - In-depth Tutorial


autograd 流程圖

例 1.
a = torch.tensor(2.0)
b = torch.tensor(3.0)
c = a*b 

前項的計算圖如下:
每個方框代表一個tensor,其中列出一些屬性(還有其他很多屬性):

  • .data 存了tensor的data
  • .grad 當計算gradient的時候將會存入此函數對應情況下的gradient
  • .grad_fn 指向用於backward的函數的節點
  • .is_leaf 判斷是否是葉節點 (關於葉節點的信息請移步: one way的pytorch學習筆記(三)leaf 葉子(張量))
  • .requires_grad 如果是設爲True,那麼在做backward時候將作爲圖的一部分參與backwards運算,如果爲False則不參加backwards運算

在圖中可見, c= a*b的運算也算作計算圖的一部分, 用Mul表示.由於a 和b 是require_grad,所以自動的被算爲 is_leaf 爲True, 至於爲什麼請見one way的pytorch學習筆記(三)leaf 葉子(張量) ,此時由於requires_grad都爲False,因此沒有backwards的graph.

例 2.
a = torch.tensor(2.0,requires_grad = True)
b = torch.tensor(3.0)
c = a*b 
c.backward()

重新進行計算,設tensor arequires_gradTrue,輸出結果c因爲輸入自變量的屬性爲True而自動改變成 requires_gradTrue.這說明只要自變量中有一個requires_gradTrue, 進一步通過運算生成的變量也爲True. 此時的c 爲非葉節點, grad_fn 指向做backwards時與當前變量相關的backwards的函數(函數爲pytorch自動生成的).

這是算數式的前饋流程圖, 使我們可以通過函數代碼觀察到的.在此流程圖的背後,其實pytorch還自動生成了對應的backwards的流程圖:
  1. 當我們調用tensor的乘法函數時,同時調用了隱性變量 ctx (context)變量的save_for_backward(),這樣就把此函數做backward時所需要的從forward函數中獲取的相關的一些值存到了ctx中.ctx起到了緩存相關參數的作用,變成連接forward與backward之間的緩存站. ctx中的值將會在c 做backwards時傳遞給對應的Mulbackward 操作.
  2. 與此同時 由於c是通過 c=a*b運算得來的, c的grad_fn中存了做backwards時候對應的函數.且把這個對應的backward 叫做 “MulBackward
  3. 當進行c的backwards的時候,其實也就相當於執行了 c = a*b這個函數分別對 ab 做的偏導. 那麼理應對應兩組backwards的函數,這兩組backwards的函數打包存在 MulBackward的 next_functions 中. nex_function爲一個 tuple list, AccumulateGrad 將會把相應得到的結果送到 a.grad中.
  4. 於是在進行 c.backward() 後, c進行關於a以及關於b的求導,由於brequires_gradFalse,因此b項不參與backwards運算(自然的,next_function中list的第二個tuple即爲None),c關於a的梯度爲3,因此3將傳遞給AccumulaGrad進一步傳給a.grad 因此 a.grad 的結果爲3

從這個backward graph中,可以看出,其實pytorch在定義這些變量的運算函數時,其實也定義了函數對應的backwards的函數.如果想使用自定義的函數,那麼自己也必須要定義backwards函數.

例 3.
a = torch.tensor(2.0,requires_grad = True)
b = torch.tensor(3.0,requires_grad = True)
c = a*b 
d = torch.tensor(4.0,requires_grad = True)
e = c*d
e.backward()
  1. egrad_fn 指向節點 MulBackward, cgrad_fn指向另一個節點 MulBackward
  2. c 爲中間值is_leafFalse,因此並不包含 grad值,在backward計算中,並不需要再重新獲取c.grad的值, backward的運算直接走相應的backward node 即可
  3. MulBackward 從 ctx.saved_tensor中調用有用信息, e= c+de關於c的梯度通過MulBackward 獲取得4. 根據鏈式規則, 4再和上一階段的 c關於 ac關於b的兩個梯度值3和2相乘,最終得到了相應的值12 和8
  4. a.grad 中存入12, b.grad中存入 8
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章