pytorch用法记录(torch.Storage与detach)

1.torch.Storage类

 

 

 

 使用storage()函数把Tensor数据转换为float类型的Storage数据,再使用tolist() 返回一个包含此存储中元素的列表。

2.detach 计算图截断

detach 的意思是,这个数据和生成它的计算图“脱钩”了,即detach就是截断反向传播的梯度流。GAN中,Train D on fake,G生成的数据会传入D,然后计算loss,再反向传播更新。由于backward()的操作是我们希望D(判别器)端的loss更新D但不要影响到 G(生成器)。

#  1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()

如上例,对G生成的数据G(d_gen_input)执行detach()操作,判别器D梯度反向传播,就到它自己身上为止,不会继续反向传播到G。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章