GAN網絡走過的坑

所用語言:PyTorch,python3.6

問題

1、解決方案引用[1]

復現DCGAN代碼時出現ERROR如下:
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

後在所有的loss中添加retain_graph=True,解決了該問題。
loss_d.backward(retain_graph=True)
但伴隨着出現瞭如下問題
out of memory
在gen_data後加detach(),刪除上面添加的retain_graph=True,解決了上述所有問題

d_fake = discriminator(gen_data.detach())

另:[1]中還具體介紹了GAN先G再D,與先D再G的不同

2、所有參考代碼

CDCGAN

參考文獻

[1] Pytorch: detach 和 retain_graph

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章