torch.Tensor.detach()的使用
detach()的官方說明如下:
Returns a new Tensor, detached from the current graph.
The result will never require gradient.
假設有模型A和模型B,我們需要將A的輸出作爲B的輸入,但訓練時我們只訓練模型B. 那麼可以這樣做:
input_B = output_A.detach()
它可以使兩個計算圖的梯度傳遞斷開,從而實現我們所需的功能。