# Standard GAN training loop: for each batch, update the discriminator D on
# real vs. generated images, then update the generator G to fool the updated D.
for epoch in range(EPOCH):
    sum_D = 0.0  # running sum of discriminator loss over the epoch
    sum_G = 0.0  # running sum of generator loss over the epoch
    for step, (images, images_label) in enumerate(train_loader):
        print(step)

        # Sample latent noise and generate fake images.
        G_ideas = t.randn((BATCH_SIZE, Len_Z, 1, 1))
        G_paintings = G(G_ideas)

        # ---- Discriminator step ----
        # D should assign high probability to real images, low to fakes.
        # detach() the fakes so this backward pass does not build/backprop
        # gradients into G — this also removes the need for retain_graph=True.
        prob_real = t.squeeze(D(images))
        prob_fake = t.squeeze(D(G_paintings.detach()))
        errD_real = criterion(prob_real, label_Real)
        errD_fake = criterion(prob_fake, label_Fake)
        errD = errD_real + errD_fake
        optimD.zero_grad()
        errD.backward()
        optimD.step()

        # ---- Generator step ----
        # Re-evaluate D on the fakes AFTER the D update so G is trained
        # against the current discriminator, not a stale one.
        prob_fake_for_G = t.squeeze(D(G_paintings))
        errG = criterion(prob_fake_for_G, label_Real)  # G wants fakes labeled real
        optimG.zero_grad()
        errG.backward()
        optimG.step()

        # .item() extracts a plain Python float, detaching from the autograd
        # graph — accumulating the raw tensor would keep every graph alive
        # and leak memory across iterations.
        sum_D += errD.item()
        sum_G += errG.item()
# NOTE: During experiments, accumulating with `sum_D = sum_D + errD` made memory
# climb rapidly: each addition kept a reference to that iteration's autograd
# graph, so the graph (and all its activations) grew without bound. Switching
# to `sum_D = sum_D + errD.detach().numpy()` (or, more simply, `errD.item()`)
# breaks the graph reference and keeps memory flat.