關於loss.backward()以及其參數retain_graph的一些坑

關於loss.backward()以及其參數retain_graph的一些坑

首先,loss.backward()這個函數很簡單,就是計算與圖中葉子結點有關的當前張量的梯度
使用呢,當然可以直接如下使用

    optimizer.zero_grad() 清空過往梯度;
    loss.backward() 反向傳播,計算當前梯度;
    optimizer.step() 根據梯度更新網絡參數

or這種情況
    for i in range(num):
        loss+=Loss(input,target)
    optimizer.zero_grad() 清空過往梯度;
    loss.backward() 反向傳播,計算當前梯度;
    optimizer.step() 根據梯度更新網絡參數

 但是,有些時候會出現這樣的錯誤:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed

這個錯誤的意思就是,Pytorch的機制是每次調用.backward()都會free掉所有buffers,模型中可能有多次backward(),而前一次backward()存儲在buffer中的梯度,會因爲後一次調用backward()被free掉,因此,這裏需要用到retain_graph=True這個參數
使用這個參數,可以讓前一次的backward()的梯度保存在buffer內,直到更新完成,但是要注意,如果你是這樣寫的:

    optimizer.zero_grad() 清空過往梯度;
    loss1.backward(retain_graph=True) 反向傳播,計算當前梯度;
    loss2.backward(retain_graph=True) 反向傳播,計算當前梯度;
    optimizer.step() 根據梯度更新網絡參數

 那麼你可能會出現內存溢出的情況,並且,每一次迭代會比上一次更慢,越往後越慢(因爲你的梯度都保存了,沒有free)
解決的方法,當然是這樣:

    optimizer.zero_grad() 清空過往梯度;
    loss1.backward(retain_graph=True) 反向傳播,計算當前梯度;
    loss2.backward() 反向傳播,計算當前梯度;
    optimizer.step() 根據梯度更新網絡參數

即:最後一個backward()不要加retain_graph參數,這樣每次更新完成後會釋放佔用的內存,也就不會出現越來越慢的情況了。

這裏有人就會問了,我又沒有這麼多 loss,怎麼還會出現這種錯誤呢?這裏可能是因爲,你用的模型本身有問題,LSTM和GRU都會出現這樣的問題,問題存在與hidden unit,這個東東也參與了反向傳播,所以導致了有多個backward(),
這裏其實我也挺費解,爲什麼存在多個backward()呢?難道是,我的LSTM網絡是N to N,即輸入N和,輸出N個,然後和N個label進行計算loss,再進行回傳,這裏,可以思考一下BPTT,即,如果是N to 1,那麼梯度更新需要時間序列所有的輸入以及隱藏變量計算梯度,然後從最後一個向前傳,所以只有一個backward(), 而N to N 以及 N to M 都會出現多個loss需要進行backward()的情況,如果還是兩個方向(一個從輸出到輸入,一個沿着時間)一直進行傳播,那麼其實會有重疊的部分,所以解決的方法也就很明瞭了,利用detach()函數,切斷其中重疊的反向傳播,(這裏僅是我的個人理解,若有誤還請評論指出,大家共同探討)切斷的方式有三種,如下:

hidden.detach_()
hidden = hidden.detach()
hidden = Variable(hidden.data, requires_grad=True) 

任選其一即可, 這裏附一些我參考的解釋,大家可以看看

Help clarifying repackage_hidden in word_language_model
Pytorch 如何實現訓練LSTM的BPTT算法?
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed
Pytorch中retain_graph參數的作用

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章