Implementing a GAN in PyTorch that learns a normal distribution

GANs seem complicated? How to train a GAN in under 50 lines of code (with PyTorch)

       In 2014, Ian Goodfellow and his colleagues at the Université de Montréal published a landmark paper, "Generative Adversarial Nets," marking the birth of the generative adversarial network (GAN) through an innovative combination of computational graphs and game theory. Their work showed that, given sufficient modeling capacity, two competing models can be trained jointly through plain backpropagation. The two models have clearly defined roles. Given a real dataset R, G is the generator, whose job is to produce fake data convincing enough to pass as real; D is the discriminator, which receives data either from the real dataset or from G and labels it real or fake. Goodfellow's analogy: G is a forgery workshop trying to make its wares indistinguishable from the genuine articles, while D is the art expert who must tell originals from high-quality fakes (except that in this setup the forger G never sees the originals; it only gets D's verdicts, so it is working blind).

       Ideally, both D and G keep improving as training proceeds, until G has essentially become a master forger and D, unable to tell the two data distributions apart, loses the game. In practice, what Goodfellow demonstrated is that G ends up performing a kind of unsupervised learning on the original dataset, discovering a way to represent the data in a lower-dimensional manner.
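Formally, this adversarial game is usually stated as the minimax objective from the original paper (standard notation, not taken from this post's code): D maximizes its log-likelihood of labeling correctly while G minimizes it.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_z(z)}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]
```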

 Below, we get to know GANs by implementing one that learns a normal distribution.

 First, import the PyTorch dependencies.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal
# Define the target normal distribution:
# its mean and standard deviation, and the length of each sample series
data_mean = 3.0
data_stddev = 0.4
Series_Length = 30

# Generator (G) dimensions
# G takes random noise as input and learns to produce series matching the distribution defined above
g_input_size = 20    
g_hidden_size = 150  
g_output_size = Series_Length

# Discriminator (D) dimensions
# D outputs a value near 1.0 if its input looks like the target normal distribution, near 0.0 if it does not
d_input_size = Series_Length
d_hidden_size = 75   
d_output_size = 1

# Minibatch sizes and training schedule
d_minibatch_size = 15 
g_minibatch_size = 10
num_epochs = 5000
print_interval = 1000

# Learning rates
d_learning_rate = 3e-3
g_learning_rate = 8e-3
# The two helpers below build samplers: one draws from the real distribution, the other produces noise.
# Real samples are used to train the Discriminator; noise is the Generator's input.

def get_real_sampler(mu, sigma):
    dist = Normal(mu, sigma)
    return lambda m, n: dist.sample((m, n)).requires_grad_()

def get_noise_sampler():
    return lambda m, n: torch.rand(m, n).requires_grad_()  # uniform noise into the generator, NOT Gaussian
 
actual_data = get_real_sampler( data_mean, data_stddev )
noise_data  = get_noise_sampler()
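As a quick sanity check, here is a sketch that re-creates the two samplers with the hyperparameters above: the real sampler returns an m × n tensor of Gaussian draws, the noise sampler an m × n tensor of uniform draws in [0, 1).

```python
import torch
from torch.distributions.normal import Normal

def get_real_sampler(mu, sigma):
    dist = Normal(mu, sigma)
    return lambda m, n: dist.sample((m, n)).requires_grad_()

def get_noise_sampler():
    return lambda m, n: torch.rand(m, n).requires_grad_()

real = get_real_sampler(3.0, 0.4)(15, 30)   # 15 series, each of length 30
noise = get_noise_sampler()(10, 20)         # generator input: uniform in [0, 1)

print(tuple(real.shape))    # (15, 30)
print(tuple(noise.shape))   # (10, 20)
print(abs(real.mean().item() - 3.0) < 0.5)  # True: sample mean is near 3.0
```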
# The generator is a simple three-layer fully connected network; its output series should come to match the target normal distribution.

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.xfer = torch.nn.SELU()

    def forward(self, x):
        x = self.xfer(self.map1(x))
        x = self.xfer(self.map2(x))
        return self.xfer(self.map3(x))
# The discriminator is a similar feed-forward network; its sigmoid output is the probability that the input is real (near 1.0) rather than generated (near 0.0)
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        x = self.elu(self.map1(x))
        x = self.elu(self.map2(x))
        return torch.sigmoid(self.map3(x))
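The two networks can be shape-checked with equivalent nn.Sequential stacks (a sketch using the dimensions defined earlier; the layer structure mirrors the classes above):

```python
import torch
import torch.nn as nn

# Equivalent stacks with this article's dimensions
G = nn.Sequential(nn.Linear(20, 150), nn.SELU(),
                  nn.Linear(150, 150), nn.SELU(),
                  nn.Linear(150, 30), nn.SELU())
D = nn.Sequential(nn.Linear(30, 75), nn.ELU(),
                  nn.Linear(75, 75), nn.ELU(),
                  nn.Linear(75, 1), nn.Sigmoid())

z = torch.rand(10, 20)     # a minibatch of 10 noise vectors
fake = G(z)
score = D(fake)
print(tuple(fake.shape))   # (10, 30): one generated series per row
print(tuple(score.shape))  # (10, 1): one realness score per series
print(bool(((score >= 0) & (score <= 1)).all()))  # True: sigmoid output
```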
# Build the networks; use SGD optimizers and binary cross-entropy (BCELoss) as the criterion

G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)
 
criterion = nn.BCELoss() 
d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate)
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate)
# Training: each epoch updates D on one real and one generated batch, then updates G

def train_D_on_actual():
    real_data = actual_data(d_minibatch_size, d_input_size)
    decision = D(real_data)
    error = criterion(decision, torch.ones(d_minibatch_size, 1))  # target 1 = real
    error.backward()

def train_D_on_generated():
    noise = noise_data(d_minibatch_size, g_input_size)
    fake_data = G(noise).detach()  # detach so D's update does not backpropagate into G
    decision = D(fake_data)
    error = criterion(decision, torch.zeros(d_minibatch_size, 1))  # target 0 = fake
    error.backward()


def train_G():
    noise = noise_data(g_minibatch_size, g_input_size)
    fake_data = G(noise)
    fake_decision = D(fake_data)
    error = criterion(fake_decision, torch.ones(g_minibatch_size, 1))  # we want to fool, so pretend it's all genuine

    error.backward()
    return error.item(), fake_data
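The target of all ones in train_G is what makes the setup adversarial: G minimizes the same BCE loss that D uses for real data, so G's loss is small exactly when D scores its fakes as real. A small illustration with hypothetical discriminator scores (not taken from the training run):

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()
ones = torch.ones(4, 1)  # G pretends everything it made is genuine

d_says_fake = torch.full((4, 1), 0.1)  # D is confident the samples are fake
d_is_fooled = torch.full((4, 1), 0.9)  # D scores the fakes as real

loss_when_caught = bce(d_says_fake, ones)  # -log(0.1) ≈ 2.303: bad for G
loss_when_fooled = bce(d_is_fooled, ones)  # -log(0.9) ≈ 0.105: good for G
print(loss_when_caught.item() > loss_when_fooled.item())  # True
```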

losses = []

for epoch in range(num_epochs):
    D.zero_grad()
    
    train_D_on_actual()    
    train_D_on_generated()
    d_optimizer.step()
    
    G.zero_grad()
    loss, generated = train_G()
    g_optimizer.step()

    losses.append(loss)
    if (epoch + 1) % print_interval == 0:
        print("Epoch %6d. Loss %5.3f" % (epoch + 1, loss))

print("Training complete")

Output:

Epoch   1000. Loss 0.630
Epoch   2000. Loss 0.693
Epoch   3000. Loss 0.699
Epoch   4000. Loss 0.695
Epoch   5000. Loss 0.711
Training complete

Visualizing the results:

import matplotlib.pyplot as plt

def draw(data):
    plt.figure()
    d = data.tolist() if isinstance(data, torch.Tensor) else data
    plt.plot(d)
    plt.show()

# Bin each generated series into a histogram over [0, 5] and plot one curve per series
d = torch.empty(generated.size(0), 53)
for i in range(d.size(0)):
    d[i] = torch.histc(generated[i], min=0, max=5, bins=53)
draw(d.t())
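torch.histc does the binning here: each generated series is counted into 53 equal-width buckets over [0, 5], and plotting the transposed matrix gives one histogram curve per series. A minimal illustration of the binning behavior (5 bins instead of 53, for readability):

```python
import torch

# 5 bins of width 1 over [0, 5]; values outside [min, max] are ignored
x = torch.tensor([0.5, 1.5, 2.5, 2.6, 4.9])
h = torch.histc(x, bins=5, min=0, max=5)
print(h.tolist())  # [1.0, 1.0, 2.0, 0.0, 1.0]
```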
For comparison, the GAN from reference 1, trained on a normal distribution, produces output like the following:

Epoch 0: D (0.824586033821106 real_err, 0.5667324066162109 fake_err) G (0.8358850479125977 err); Real Dist ([4.04999169665575, 1.1961469779141327]),  Fake Dist ([-0.7985607595443726, 0.029085312190181612]) 
Epoch 100: D (0.6598804593086243 real_err, 0.650113582611084 fake_err) G (0.7386993169784546 err); Real Dist ([3.992502051591873, 1.2512647908868664]),  Fake Dist ([3.3269334855079653, 0.04917005677031684]) 
Epoch 200: D (0.5211157202720642 real_err, 0.48223015666007996 fake_err) G (0.9489601254463196 err); Real Dist ([4.063567101441324, 1.223350898628472]),  Fake Dist ([4.171090561866761, 0.08437742403922882]) 
Epoch 300: D (0.6931817531585693 real_err, 0.6667982935905457 fake_err) G (0.7198074460029602 err); Real Dist ([3.9114013223052027, 1.1737565311661027]),  Fake Dist ([5.706633312940598, 2.3018437903116187]) 
Epoch 400: D (0.6633034944534302 real_err, 0.6330634951591492 fake_err) G (0.7584028840065002 err); Real Dist ([4.023453078389168, 1.173117767184327]),  Fake Dist ([6.466390012741089, 1.767663960120162]) 
Epoch 500: D (0.6556934118270874 real_err, 0.6795700192451477 fake_err) G (0.6501794457435608 err); Real Dist ([3.973978979229927, 1.2705711235753019]),  Fake Dist ([3.3727304997444154, 1.7591466849210384]) 
Epoch 600: D (0.5973175764083862 real_err, 0.5650959610939026 fake_err) G (0.7776590585708618 err); Real Dist ([3.8488472831845284, 1.1891462464838336]),  Fake Dist ([9.568378679275513, 0.4357023789663153]) 
Epoch 700: D (0.6975439786911011 real_err, 0.7180769443511963 fake_err) G (0.6462653279304504 err); Real Dist ([3.932738121032715, 1.224346374745391]),  Fake Dist ([3.466270624160767, 1.48932632524634]) 
Epoch 800: D (0.6946223378181458 real_err, 0.7050127983093262 fake_err) G (0.680088460445404 err); Real Dist ([4.02625471547246, 1.29204413337018]),  Fake Dist ([4.613226325511932, 0.5539247818601367]) 
Epoch 900: D (0.44320639967918396 real_err, 0.23536285758018494 fake_err) G (1.3666585683822632 err); Real Dist ([3.9780734790563583, 1.1573922540562789]),  Fake Dist ([4.808220813274383, 1.064950020735757]) 
Epoch 1000: D (0.885653555393219 real_err, 0.6489665508270264 fake_err) G (0.7260928153991699 err); Real Dist ([4.0037252243161205, 1.2868414641703296]),  Fake Dist ([4.627316523551941, 1.5317822031475823]) 
Epoch 1100: D (0.6624018549919128 real_err, 0.7288448810577393 fake_err) G (0.6776896119117737 err); Real Dist ([3.9965664629936217, 1.2443392558363737]),  Fake Dist ([4.027970834732056, 1.2487745961745749]) 
Epoch 1200: D (0.7081807851791382 real_err, 0.6922528147697449 fake_err) G (0.694278359413147 err); Real Dist ([3.8926908799931406, 1.2083719715743726]),  Fake Dist ([4.002916200637817, 1.2386345105986898]) 
Epoch 1300: D (0.697959303855896 real_err, 0.6919403076171875 fake_err) G (0.6800123453140259 err); Real Dist ([3.938853431105614, 1.2639616232174713]),  Fake Dist ([4.216027449607849, 1.1380480721612078]) 
Epoch 1400: D (0.6784661412239075 real_err, 0.6617948412895203 fake_err) G (0.6870113611221313 err); Real Dist ([3.975500011086464, 1.2855628816411546]),  Fake Dist ([4.137257863044739, 1.1094145415740049]) 
Epoch 1500: D (0.6794069409370422 real_err, 0.7028912305831909 fake_err) G (0.7014129161834717 err); Real Dist ([4.00296139895916, 1.271967197931525]),  Fake Dist ([3.9285888566970826, 1.1982835178367897]) 
Epoch 1600: D (0.6860584020614624 real_err, 0.7060790657997131 fake_err) G (0.6802226901054382 err); Real Dist ([4.058289145439863, 1.361246466914028]),  Fake Dist ([4.036336513757706, 1.2252188340907701]) 
Epoch 1700: D (0.7053802609443665 real_err, 0.696638822555542 fake_err) G (0.6959577202796936 err); Real Dist ([4.022835171103478, 1.2827733756994895]),  Fake Dist ([4.063137277841568, 1.3042375271806848]) 
Epoch 1800: D (0.6980652213096619 real_err, 0.6925143599510193 fake_err) G (0.694028913974762 err); Real Dist ([3.89647038769722, 1.1976831036253377]),  Fake Dist ([4.048781845808029, 1.2693529337960752]) 
Epoch 1900: D (0.6892062425613403 real_err, 0.6972301006317139 fake_err) G (0.6869157552719116 err); Real Dist ([4.033831748008728, 1.186610768225342]),  Fake Dist ([4.022395441293717, 1.2413968482619808]) 
Epoch 2000: D (0.6959377527236938 real_err, 0.6776554584503174 fake_err) G (0.6982696056365967 err); Real Dist ([4.000252573490143, 1.223225489126993]),  Fake Dist ([3.9199998137950898, 1.3520746107193355]) 
Epoch 2100: D (0.6903250813484192 real_err, 0.6950635313987732 fake_err) G (0.6896666288375854 err); Real Dist ([4.090404322504997, 1.2083467717828416]),  Fake Dist ([3.9939954011440277, 1.2462546765450804]) 
Epoch 2200: D (0.6896629333496094 real_err, 0.6950464248657227 fake_err) G (0.6950165033340454 err); Real Dist ([3.990626331356354, 1.2201043116027737]),  Fake Dist ([4.001183384895325, 1.2655219236978008]) 
Epoch 2300: D (0.6915993690490723 real_err, 0.6974573135375977 fake_err) G (0.6993840336799622 err); Real Dist ([4.023877061843872, 1.321507703580412]),  Fake Dist ([3.9179794733524322, 1.1990433265674747]) 
Epoch 2400: D (0.6870835423469543 real_err, 0.692215621471405 fake_err) G (0.6913619637489319 err); Real Dist ([3.921516800969839, 1.2407385883672248]),  Fake Dist ([4.060535876750946, 1.1816568966425942]) 
Epoch 2500: D (0.6934722065925598 real_err, 0.6907936334609985 fake_err) G (0.6907175779342651 err); Real Dist ([3.9839096758961676, 1.2733894018432854]),  Fake Dist ([3.990558073759079, 1.3400137248202226]) 
Epoch 2600: D (0.6947907209396362 real_err, 0.6915742754936218 fake_err) G (0.6919546127319336 err); Real Dist ([4.0130539444088935, 1.240739679033649]),  Fake Dist ([4.084916990995407, 1.2655242770514346]) 
Epoch 2700: D (0.6968568563461304 real_err, 0.6744566559791565 fake_err) G (0.6968263387680054 err); Real Dist ([3.996440895199776, 1.1986359766448267]),  Fake Dist ([3.9793724551200866, 1.1482623145157824]) 
Epoch 2800: D (0.693473756313324 real_err, 0.697898805141449 fake_err) G (0.7002295255661011 err); Real Dist ([4.00081592977047, 1.2474840907225078]),  Fake Dist ([3.895143656253815, 1.2719435213827661]) 
Epoch 2900: D (0.6950169801712036 real_err, 0.6947277188301086 fake_err) G (0.6900919675827026 err); Real Dist ([4.0483872441053395, 1.232666877408606]),  Fake Dist ([3.9184666118621827, 1.2345984703674788]) 
Epoch 3000: D (0.6994454860687256 real_err, 0.6960627436637878 fake_err) G (0.6918361186981201 err); Real Dist ([4.078680800318718, 1.2215170294711815]),  Fake Dist ([4.022074975967407, 1.1967591283090955]) 
Epoch 3100: D (0.6898576617240906 real_err, 0.6938331127166748 fake_err) G (0.6958847641944885 err); Real Dist ([3.906800366342068, 1.3110840468016158]),  Fake Dist ([3.951125014066696, 1.2253583646427406]) 
Epoch 3200: D (0.694175660610199 real_err, 0.6946524977684021 fake_err) G (0.6921048760414124 err); Real Dist ([3.958098252296448, 1.248056967946781]),  Fake Dist ([4.001224509239197, 1.1983827779563796]) 
Epoch 3300: D (0.6922207474708557 real_err, 0.6947858333587646 fake_err) G (0.6927611231803894 err); Real Dist ([3.8829670441150665, 1.1155963206788106]),  Fake Dist ([4.046220509767532, 1.1753880201920783]) 
Epoch 3400: D (0.6900198459625244 real_err, 0.6953887939453125 fake_err) G (0.6889302134513855 err); Real Dist ([4.065795364975929, 1.213252057901399]),  Fake Dist ([3.9650485837459564, 1.2672685373911108]) 
Epoch 3500: D (0.6929183006286621 real_err, 0.695112943649292 fake_err) G (0.6910595297813416 err); Real Dist ([3.8757613455876707, 1.2584639089844412]),  Fake Dist ([3.9378762912750243, 1.190689370381666]) 
Epoch 3600: D (0.693882942199707 real_err, 0.6944809556007385 fake_err) G (0.6939374804496765 err); Real Dist ([4.123380273818969, 1.2824410770958474]),  Fake Dist ([4.010826068401337, 1.2212849080025636]) 
Epoch 3700: D (0.6974205374717712 real_err, 0.6935890913009644 fake_err) G (0.6917606592178345 err); Real Dist ([4.021837902694941, 1.28027136741628]),  Fake Dist ([4.034767779827118, 1.3234349547715394]) 
Epoch 3800: D (0.6955257654190063 real_err, 0.6945357322692871 fake_err) G (0.6925974488258362 err); Real Dist ([4.12936865234375, 1.2460711614374878]),  Fake Dist ([4.0321874620914455, 1.2769764427346884]) 
Epoch 3900: D (0.6915967464447021 real_err, 0.6909477114677429 fake_err) G (0.6927705407142639 err); Real Dist ([4.0268408809900285, 1.2063883800130077]),  Fake Dist ([4.052658556222916, 1.2281464364273882]) 
Epoch 4000: D (0.6922350525856018 real_err, 0.6925557255744934 fake_err) G (0.6929080486297607 err); Real Dist ([4.021845901966095, 1.2925729163376942]),  Fake Dist ([4.025567240476608, 1.1972493940735556]) 
Epoch 4100: D (0.6933664679527283 real_err, 0.6933354139328003 fake_err) G (0.6938549280166626 err); Real Dist ([3.989677229881287, 1.2065878529125207]),  Fake Dist ([4.0539454262256625, 1.2933863721439718]) 
Epoch 4200: D (0.6897806525230408 real_err, 0.6932942867279053 fake_err) G (0.6924738883972168 err); Real Dist ([3.9900609830617904, 1.271517711087724]),  Fake Dist ([3.9614700605869295, 1.2921971453653849]) 
Epoch 4300: D (0.6924872398376465 real_err, 0.6926604509353638 fake_err) G (0.6937258839607239 err); Real Dist ([4.0992556612789635, 1.2569412389872303]),  Fake Dist ([4.127795008897781, 1.2884594395504811]) 
Epoch 4400: D (0.6946849822998047 real_err, 0.6911969184875488 fake_err) G (0.6942746639251709 err); Real Dist ([4.076893085479736, 1.2744374182411338]),  Fake Dist ([3.969561124563217, 1.2501441583969877]) 
Epoch 4500: D (0.6914315819740295 real_err, 0.6935481429100037 fake_err) G (0.6912876963615417 err); Real Dist ([3.94799566257, 1.2232376272767607]),  Fake Dist ([4.002278556823731, 1.2505587333284056]) 
Epoch 4600: D (0.7037312388420105 real_err, 0.6909893155097961 fake_err) G (0.6984265446662903 err); Real Dist ([3.9354238008633255, 1.2836771049555928]),  Fake Dist ([3.9731800141334532, 1.2964666046336777]) 
Epoch 4700: D (0.6944875121116638 real_err, 0.6940887570381165 fake_err) G (0.6929762363433838 err); Real Dist ([4.047859039783478, 1.190815309623828]),  Fake Dist ([3.990305748939514, 1.2529010827404359]) 
Epoch 4800: D (0.6942726373672485 real_err, 0.692735493183136 fake_err) G (0.6925361752510071 err); Real Dist ([3.9683418440818787, 1.2123594456412528]),  Fake Dist ([4.032746832132339, 1.213276069713684]) 
Epoch 4900: D (0.6940117478370667 real_err, 0.6929702758789062 fake_err) G (0.6935417652130127 err); Real Dist ([3.947842192411423, 1.2553405683297756]),  Fake Dist ([4.038874376773834, 1.2395540432380345]) 
Plotting the generated distribution...
 Values: [4.539769649505615, 1.63444983959198, 4.4557600021362305, 5.5528364181518555, 4.340044021606445, 2.198087453842163, 2.9329943656921387, 2.3880555629730225, 1.5384184122085571, 4.269472122192383, 4.8173933029174805, 4.85664176940918, 1.6450153589248657, 3.9355251789093018, 5.932168960571289, 1.3862096071243286, 3.9333181381225586, 3.1726975440979004, 3.696260452270508, 3.052525520324707, 3.3560595512390137, 3.0390870571136475, 2.6469674110412598, 4.103339195251465, 4.503757476806641, 4.599853038787842, 4.138694763183594, 4.259690284729004, 4.146164417266846, 1.6029934883117676, 5.843445777893066, 2.2497828006744385, 3.3378164768218994, 1.9207507371902466, 2.179079532623291, 4.690481185913086, 4.0812296867370605, 4.095373630523682, 3.8689160346984863, 6.373284816741943, 4.976202964782715, 2.3788414001464844, 4.4080915451049805, 4.072725772857666, 3.577080726623535, 1.4012407064437866, 3.4651074409484863, 1.9623477458953857, 2.1658592224121094, 4.821807861328125, 5.086670875549316, 4.142153739929199, 4.716959476470947, 6.827563285827637, 2.624016046524048, 4.119378089904785, 4.425724983215332, 2.424050807952881, 4.140542030334473, 4.075761795043945, 1.755760908126831, 1.417391300201416, 2.0366601943969727, 4.73690128326416, 5.473508834838867, 1.9511061906814575, 1.465293526649475, 4.018528938293457, 3.181847095489502, 4.6452250480651855, 3.0562477111816406, 4.425663948059082, 4.1797943115234375, 3.442857503890991, 4.425789833068848, 4.443751335144043, 3.9827144145965576, 4.548138618469238, 6.538724899291992, 4.286013603210449, 5.2574357986450195, 4.251628875732422, 3.104124069213867, 4.130465984344482, 5.224208354949951, 5.200654983520508, 4.489740371704102, 4.862637519836426, 4.596247673034668, 4.180943965911865, 1.4127655029296875, 4.495588779449463, 4.70436429977417, 5.339097023010254, 4.162871360778809, 3.3220300674438477, 3.5057907104492188, 6.428955078125, 4.321961879730225, 3.730902671813965, 2.5903165340423584, 4.3856611251831055, 2.466325044631958, 
4.158637046813965, 4.462804317474365, 2.2380805015563965, 6.760380744934082, 6.239953994750977, 6.000397682189941, 3.1869945526123047, 3.6029326915740967, 4.3395891189575195, 4.260036945343018, 5.334171295166016, 4.524747848510742, 2.57928204536438, 4.054173946380615, 4.614095687866211, 4.792035102844238, 4.328861236572266, 2.879300832748413, 4.3328118324279785, 4.333817481994629, 4.411715507507324, 4.565466403961182, 5.426932334899902, 2.719020366668701, 4.017667770385742, 1.9505842924118042, 3.4826974868774414, 5.189241409301758, 3.940156936645508, 4.229568958282471, 5.3276214599609375, 2.97536301612854, 4.3470964431762695, 4.007474899291992, 4.351943016052246, 5.556241989135742, 5.733500003814697, 3.7382421493530273, 4.794580459594727, 3.421999454498291, 7.074246406555176, 4.084506988525391, 3.7553157806396484, 3.855440616607666, 1.9251701831817627, 1.8792188167572021, 4.106200695037842, 6.982327461242676, 4.308199882507324, 4.918728828430176, 3.0963966846466064, 4.6833295822143555, 6.2155351638793945, 4.364574909210205, 4.047039031982422, 6.688213348388672, 4.255982398986816, 3.465512990951538, 4.640872001647949, 4.323538780212402, 3.897338390350342, 4.843024730682373, 4.405055046081543, 4.61812162399292, 4.554985046386719, 2.2624690532684326, 3.297863245010376, 4.553571701049805, 6.009780406951904, 4.676636219024658, 3.8342180252075195, 4.457300662994385, 4.081836700439453, 3.9818239212036133, 2.5001485347747803, 4.403700828552246, 3.586491107940674, 4.333785057067871, 4.095248222351074, 4.599881172180176, 4.317112922668457, 3.5421624183654785, 1.306177020072937, 4.3163161277771, 4.383136749267578, 6.219839572906494, 4.408424377441406, 3.992192268371582, 4.691045761108398, 2.5535831451416016, 3.9267373085021973, 6.662070274353027, 5.228813648223877, 4.331976413726807, 4.020612716674805, 4.450967788696289, 4.163876533508301, 5.478597640991211, 4.078657627105713, 4.090933799743652, 2.7492101192474365, 2.77490496635437, 6.481786727905273, 3.1561057567596436, 
1.6670697927474976, 3.084937572479248, 3.8840131759643555, 4.455242156982422, 4.475409507751465, 4.464001655578613, 4.2673821449279785, 3.9205546379089355, 1.739898443222046, 4.940193176269531, 3.888833522796631, 4.795442581176758, 3.1650333404541016, 5.506227493286133, 1.7203617095947266, 2.949620246887207, 4.330982208251953, 3.9301090240478516, 3.6949357986450195, 2.6255245208740234, 4.946829795837402, 4.48328161239624, 6.919548034667969, 4.124054431915283, 3.8777058124542236, 3.122709035873413, 3.371814250946045, 4.4012603759765625, 3.430072069168091, 4.387892246246338, 4.785408020019531, 4.445903778076172, 1.794466495513916, 1.3930325508117676, 3.5082497596740723, 2.6178057193756104, 3.084986686706543, 4.436977863311768, 2.976431131362915, 3.308137893676758, 4.502436637878418, 4.193325042724609, 4.551983833312988, 4.573895454406738, 3.0142345428466797, 4.624921798706055, 3.4331541061401367, 4.483516216278076, 3.5073108673095703, 1.7782986164093018, 4.087821006774902, 4.422914505004883, 1.781771183013916, 3.595541000366211, 1.3599108457565308, 4.239243507385254, 5.854596138000488, 3.2330665588378906, 3.962944507598877, 2.631474256515503, 3.1121480464935303, 4.387513160705566, 5.491522789001465, 2.666893482208252, 3.6849842071533203, 3.743051767349243, 4.33473014831543, 4.1382317543029785, 2.630114793777466, 4.073785305023193, 4.4156599044799805, 3.0505526065826416, 3.0343894958496094, 4.244548320770264, 3.556633949279785, 4.397510051727295, 5.04244327545166, 4.884440898895264, 1.4465062618255615, 3.354637622833252, 4.454654216766357, 1.3222016096115112, 2.4021379947662354, 5.0598859786987305, 5.766423225402832, 6.0669660568237305, 4.038768768310547, 6.070646286010742, 4.291622161865234, 2.593209743499756, 5.054876327514648, 6.4829912185668945, 2.1875972747802734, 4.108073711395264, 2.046146869659424, 2.195040702819824, 5.404623031616211, 3.6802978515625, 4.330349922180176, 4.543405532836914, 1.7148712873458862, 2.423985004425049, 3.5847182273864746, 
5.643503189086914, 4.613306999206543, 6.315460205078125, 4.532344341278076, 2.582831621170044, 3.5568394660949707, 5.1542792320251465, 1.586488962173462, 5.473391532897949, 3.9910194873809814, 3.9006848335266113, 7.0849409103393555, 1.5694514513015747, 6.967430114746094, 3.9287962913513184, 4.961460113525391, 4.553676128387451, 4.890722274780273, 4.163438320159912, 4.260256767272949, 4.291888236999512, 3.651801109313965, 3.8796298503875732, 2.756232261657715, 3.793423652648926, 6.770336151123047, 1.4255211353302002, 5.299874305725098, 4.223451137542725, 1.5071499347686768, 6.17253303527832, 2.4757025241851807, 6.378934860229492, 4.166093826293945, 4.059502124786377, 7.021824836730957, 4.003268241882324, 4.963767051696777, 4.338545322418213, 2.8597476482391357, 6.183218955993652, 5.326067924499512, 4.450697422027588, 3.62861967086792, 3.7150216102600098, 4.301961898803711, 6.748984336853027, 4.184115409851074, 2.84629225730896, 1.594874382019043, 4.019465446472168, 3.8107612133026123, 4.758617877960205, 4.399703025817871, 3.3647310733795166, 3.88543701171875, 4.270716667175293, 5.62855339050293, 1.5347919464111328, 4.474246978759766, 3.4855237007141113, 3.5557379722595215, 5.3796892166137695, 4.433393955230713, 1.9664214849472046, 4.3585004806518555, 4.353857040405273, 3.5663862228393555, 4.252057075500488, 5.551637649536133, 4.304230213165283, 4.343073844909668, 2.065692901611328, 5.752276420593262, 2.9766857624053955, 3.2973098754882812, 4.038196563720703, 2.5002596378326416, 4.453817844390869, 7.091517448425293, 5.085714340209961, 5.929409027099609, 3.247377872467041, 4.521625995635986, 4.168796539306641, 4.012208461761475, 4.6597089767456055, 3.359983444213867, 2.857126235961914, 4.996077537536621, 6.846553802490234, 1.9788612127304077, 3.4982848167419434, 4.366124629974365, 5.133428573608398, 2.32782244682312, 4.406688213348389, 4.149277210235596, 5.970331192016602, 1.6854420900344849, 4.089301109313965, 6.799180507659912, 2.3591437339782715, 
3.7117772102355957, 4.3132476806640625, 5.110054969787598, 4.329547882080078, 4.5677056312561035, 2.250425338745117, 4.502649784088135, 4.357370853424072, 4.430999755859375, 4.155341148376465, 1.4976725578308105, 6.601097106933594, 2.1156344413757324, 4.512895107269287, 5.881471633911133, 3.121486186981201, 2.249671697616577, 4.071201324462891, 4.395264625549316, 2.700623035430908, 4.496452331542969, 4.578899383544922, 3.5953516960144043, 3.5068368911743164, 6.759617805480957, 4.445155143737793, 4.355457782745361, 6.24528694152832, 4.01334810256958, 3.1470675468444824, 3.2819900512695312, 4.4579925537109375, 5.387111663818359, 3.236236572265625, 3.267516851425171, 3.948960781097412, 1.4810971021652222, 3.5609662532806396, 4.5628814697265625, 5.071683406829834, 2.4318501949310303, 3.524062156677246, 4.630815505981445, 3.383183479309082, 3.4454879760742188, 4.138335704803467, 4.55290412902832, 1.575202226638794, 3.3255257606506348, 4.816848278045654, 4.594959735870361, 5.606831073760986, 3.6404035091400146, 3.7511141300201416, 3.4937477111816406, 3.108394145965576, 6.04863166809082, 2.817197561264038, 4.1551737785339355, 3.9247243404388428, 2.2736308574676514, 2.3992624282836914, 5.064456939697266, 3.6719045639038086, 4.416315078735352, 2.936709403991699, 4.053696632385254, 2.841804265975952, 2.8150277137756348, 4.072852611541748, 3.1639273166656494, 4.083523750305176, 2.416684865951538, 2.008223056793213, 6.599559783935547, 3.8771302700042725, 4.76695442199707, 4.330484390258789, 3.5413968563079834, 4.4807209968566895, 4.40773868560791, 5.183446884155273, 4.339788436889648, 4.2154693603515625, 4.499879360198975, 5.121772766113281, 2.231842279434204]

 

 

 

 

 

References:

1. https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py

2. https://github.com/rcorbish/pytorch-notebooks/blob/master/gan-basic.ipynb

3. https://www.pytorchtutorial.com/pytorch-sample-gan/
