pytorch detach().numpy()

 

    for epoch in range(EPOCH):
        sum_D = 0
        sum_G = 0
        for step, (images, imagesLabel) in enumerate(train_loader):
            print(step)
            G_ideas = t.randn((BATCH_SIZE, Len_Z, 1, 1))

            G_paintings = G(G_ideas)
            prob_artist0 = D(images)  # D try to increase this prob
            prob_artist1 = D(G_paintings)
            p0 = t.squeeze(prob_artist0)
            p1 = t.squeeze(prob_artist1)

            errD_real = criterion(p0, label_Real)

            errD_fake = criterion(p1, label_Fake)
            # errD_fake.backward()

            errD = errD_fake + errD_real
            errG = criterion(p1, label_Real)
            sum_D=sum_D+errD.detach().numpy()
            sum_G=sum_G+errG.detach().numpy()
            #print("errD is %f"%errD)
            #print("sumD is %f"%sum_D)
            optimD.zero_grad()
            errD.backward(retain_graph=True)
            optimD.step()

            optimG.zero_grad()
            errG.backward(retain_graph=True)
            optimG.step()

今天在實驗時直接使用sum_D=sum_D+errD,發現內存快速飆升。後來改成sum_D=sum_D+errD.detach().numpy(),總算沒問題了,因爲第一種表達式等於是在搭網絡節點,當然會不斷提升網絡容量,提高內存消耗量。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章