PyTorch更新部分網絡,其他不更新

torch.Tensor.detach()的使用
detach()的官方說明如下:

Returns a new Tensor, detached from the current graph.
    The result will never require gradient.
假設有模型A和模型B,我們需要將A的輸出作爲B的輸入,但訓練時我們只訓練模型B. 那麼可以這樣做:

input_B = output_A.detach()

它可以使兩個計算圖的梯度傳遞斷開,從而實現我們所需的功能。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章