pytorch用法記錄(torch.Storage與detach)

1.torch.Storage類

 

 

 

 使用storage()函數把Tensor數據轉換爲float類型的Storage數據,再使用tolist() 返回一個包含此存儲中元素的列表。

2.detach 計算圖截斷

detach 的意思是,這個數據和生成它的計算圖“脫鉤”了,即detach就是截斷反向傳播的梯度流。GAN中,Train D on fake,G生成的數據會傳入D,然後計算loss,再反向傳播更新。由於backward()的操作是我們希望D(判別器)端的loss更新D但不要影響到 G(生成器)。

#  1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()

如上例,對G生成的數據G(d_gen_input)執行detach()操作,判別器D梯度反向傳播,就到它自己身上爲止,不會繼續反向傳播到G。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章