如何在Keras中開發最大化生成對抗網絡(InfoGAN)的信息?

作者 | Jason Brownlee
編譯 | CDA數據分析師

生成對抗網絡(GAN)是一種用於訓練深度卷積模型以生成合成圖像的體系結構。

儘管非常有效,但默認GAN無法控制生成的圖像類型。信息最大化GAN(簡稱InfoGAN)是GAN架構的擴展,它引入了架構自動學習的控制變量,並允許控制生成的圖像,例如在生成圖像樣式的情況下,厚度和類型手寫的數字。

在本教程中,您將瞭解如何從頭開始實現信息最大化生成對抗網絡模型。

完成本教程後,您將瞭解:

  • InfoGAN的動機是希望解開和控制生成的圖像中的屬性。
  • InfoGAN涉及添加控制變量以生成預測控制變量的輔助模型,通過互信息損失函數進行訓練。
  • 如何從頭開發和訓練InfoGAN模型,並使用控制變量來控制模型生成的數字。

讓我們開始吧。

教程概述

本教程分爲四個部分; 他們是:

  1. 什麼是最大化GAN的信息
  2. 如何實現InfoGAN丟失功能
  3. 如何爲MNIST開發InfoGAN
  4. 如何使用訓練有素的InfoGAN模型使用控制代碼

什麼是最大化GAN的信息

Generative Adversarial Network(簡稱GAN)是一種用於訓練生成模型的體系結構,例如用於生成合成圖像的模型。

它涉及同時訓練生成器模型以生成具有鑑別器模型的圖像,該模型學習將圖像分類爲真實的(來自訓練數據集)或假的(生成的)。這兩個模型在零和遊戲中競爭,使得訓練過程的收斂涉及在生成器生成令人信服的圖像的技能與能夠檢測它們的鑑別器之間找到平衡。

生成器模型將來自潛在空間的隨機點作爲輸入,通常爲50到100個隨機高斯變量。生成器通過訓練對潛在空間中的點應用獨特的含義,並將點映射到特定的輸出合成圖像。這意味着雖然潛在空間由生成器模型構成,但是無法控制生成的圖像。

GAN公式使用簡單的因子連續輸入噪聲向量z,而對發生器可以使用該噪聲的方式沒有限制。結果,發生器可能以高度糾纏的方式使用噪聲,導致z的各個維度不對應於數據的語義特徵。

可以探索潛在空間並比較生成的圖像,以試圖理解生成器模型已經學習的映射。或者,可以例如通過類標籤來調節生成過程,以便可以按需創建特定類型的圖像。這是條件生成對抗網絡的基礎,簡稱CGAN或cGAN。

另一種方法是提供控制變量作爲發電機的輸入,以及潛在空間中的點(噪聲)。可以訓練發生器以使用控制變量來影響所生成圖像的特定屬性。這是信息最大化生成對抗網絡(簡稱InfoGAN)所採用的方法。

InfoGAN,生成對抗網絡的信息理論擴展,能夠以完全無監督的方式學習解纏結的表示。

在訓練過程中由發生器學習的結構化映射有些隨機。雖然生成器模型學習在潛在空間中空間分離生成圖像的屬性,但是沒有控制。這些屬性糾纏在一起。InfoGAN的動機是希望解開生成圖像的屬性。

例如,在面部的情況下,可以解開和控制生成面部的特性,例如面部的形狀,頭髮顏色,髮型等。

例如,對於面部的數據集,有用的解開的表示可以爲以下屬性中的每一個分配一組單獨的維度:面部表情,眼睛顏色,髮型,眼鏡的存在或不存在,以及相應人的身份。

控制變量與噪聲一起提供作爲發電機的輸入,並且通過互信息丟失功能訓練模型。

…我們對生成對抗性網絡目標進行了簡單的修改,鼓勵它學習可解釋和有意義的表達。我們通過最大化GAN噪聲變量的固定小子集與觀測值之間的互信息來實現這一點,結果證明是相對簡單的。

相互信息是指在給定另一個變量的情況下獲得的關於一個變量的信息量。在這種情況下,我們感興趣的是有關使用噪聲和控制變量生成的圖像的控制變量的信息。

在信息論中,X和Y之間的互信息I(X; Y)測量從隨機變量Y的知識中學習的關於另一個隨機變量X 的“ 信息量 ”。

相互信息(MI)被計算爲圖像的條件熵(由發生器(G)從噪聲(z)和控制變量(c)創建),給定控制變量(c)從邊際熵減去控制變量(c); 例如:

  • MI =熵(c) - 熵(c | G(z,c))

在實踐中,計算真實的互信息通常是難以處理的,儘管本文采用了簡化,稱爲變分信息最大化,並且控制代碼的熵保持不變。

通過使用稱爲Q或輔助模型的新模型來實現通過互信息訓練發電機。新模型與用於解釋輸入圖像的鑑別器模型共享所有相同的權重,但與預測圖像是真實還是假的鑑別器模型不同,輔助模型預測用於生成圖像的控制代碼。

兩種模型都用於更新生成器模型,首先是爲了提高生成愚弄鑑別器模型的圖像的可能性,其次是改善用於生成圖像的控制代碼和輔助模型對控制代碼的預測之間的互信息。

結果是生成器模型通過互信息丟失而正規化,使得控制代碼捕獲所生成圖像的顯着特性,並且反過來可以用於控制圖像生成過程。

每當我們有興趣學習從給定輸入X到保留關於原始輸入的信息的更高級別表示Y的參數化映射時,可以利用互信息。[…]表明,最大化互信息的任務基本上等同於訓練自動編碼器以最小化重建誤差。

如何實現InfoGAN丟失功能

一旦熟悉模型的輸入和輸出,InfoGAN就可以相當直接地實現。

唯一的絆腳石可能是互信息丟失功能,特別是如果你沒有像大多數開發人員那樣強大的數學背景。

InfoGan使用兩種主要類型的控制變量:分類和連續,連續變量可能具有不同的數據分佈,這會影響相互損失的計算方式。可以基於變量類型計算所有控制變量的相互損失並將其相加,這是OpenAI爲TensorFlow發佈的官方InfoGAN實現中使用的方法。

在Keras中,將控制變量簡化爲分類和高斯或均勻連續變量可能更容易,並且對於每個控制變量類型在輔助模型上具有單獨的輸出。這樣可以使用不同的損失函數,大大簡化了實現。

有關本節中建議的更多背景信息,請參閱更多閱讀部分中的文章和帖子。

分類控制變量

分類變量可用於控制所生成圖像的類型或類別。

這被實現爲一個熱編碼矢量。也就是說,如果類具有10個值,則控制代碼將是一個類,例如6,並且輸入到生成器模型的分類控制向量將是所有零值的10個元素向量,其中對於類6具有一個值,例如,[0,0,0,0,0,0,1,0,0]。

訓練模型時,我們不需要選擇分類控制變量; 相反,它們是隨機生成的,例如,每個樣本以均勻的概率選擇每個樣本。

…關於潛碼c~Cat(K = 10,p = 0.1)的統一分類分佈

在輔助模型中,分類變量的輸出層也將是一個熱編碼矢量以匹配輸入控制代碼,並且使用softmax激活函數。

對於分類潛在代碼ci,我們使用softmax非線性的自然選擇來表示Q(ci | x)。

回想一下,互信息被計算爲來自控制變量的條件熵和從提供給輸入變量的控制變量的熵中減去的輔助模型的輸出。我們可以直接實現這一點,但這不是必需的。

控制變量的熵是一個常數,並且是一個接近於零的非常小的數; 因此,我們可以從計算中刪除它。條件熵可以直接計算爲控制變量輸入和輔助模型的輸出之間的交叉熵。因此,可以使用分類交叉熵損失函數,就像我們對任何多類分類問題一樣。

超參數lambda用於縮放互信息丟失函數並設置爲1,因此可以忽略。

即使InfoGAN引入了額外的超參數λ,它也很容易調整,簡單地設置爲1就足以支持離散的潛碼。

- InfoGAN:可解釋的代表性信息學習最大化生成性對抗網,2016年。

連續控制變量

連續控制變量可用於控制圖像的樣式。

連續變量從均勻分佈中採樣,例如在-1和1之間,並作爲輸入提供給發電機模型。

…可以捕捉連續性變化的連續代碼:c2,c3~Unif(-1,1)

- InfoGAN:可解釋的代表性信息學習最大化生成性對抗網,2016年。

輔助模型可以用高斯分佈實現連續控制變量的預測,其中輸出層被配置爲具有一個節點,平均值和一個用於高斯標準偏差的節點,例如每個連續控制需要兩個輸出變量。

對於連續潛在代碼cj,根據什麼是真正的後驗P(cj | x),有更多選項。在我們的實驗中,我們發現簡單地將Q(cj | x)視爲因式高斯是足夠的。

輸出均值的節點可以使用線性激活函數,而輸出標準偏差的節點必須產生正值,因此可以使用諸如sigmoid的激活函數來創建0到1之間的值。

對於連續潛碼,我們通過對角高斯分佈對近似後驗進行參數化,識別網絡輸出其均值和標準差,其中標準偏差通過網絡輸出的指數變換進行參數化以確保積極性。

必須將損失函數計算爲高斯控制碼的互信息,這意味着它們必須在計算損失之前從平均值和標準差重建。計算高斯分佈變量的熵和條件熵可以直接實現,但不是必需的。相反,可以使用均方誤差損失。

或者,可以將輸出分佈簡化爲每個控制變量的均勻分佈,可以使用具有線性激活的輔助模型中的每個變量的單個輸出節點,並且模型可以使用均方誤差損失函數。

如何爲MNIST開發InfoGAN

在本節中,我們將仔細研究生成器(g),鑑別器(d)和輔助模型(q)以及如何在Keras中實現它們。

我們將爲MNIST數據集開發InfoGAN實現,如InfoGAN論文中所做的那樣。

本文探討了兩個版本; 第一個僅使用分類控制代碼,並允許模型將一個分類變量映射到大約一個數字(儘管沒有按分類變量排序數字)。

本文還探討了InfoGAN架構的一個版本,其中包含一個熱編碼分類變量(c1)和兩個連續控制變量(c2和c3)。

發現第一個連續變量用於控制數字的旋轉,第二個連續變量用於控制數字的粗細。

我們將重點關注使用具有10個值的分類控制變量的簡單情況,並鼓勵模型學習讓該變量控制生成的數字。您可能希望通過更改分類控制變量的基數或添加連續控制變量來擴展此示例。

用於MNIST數據集訓練的GAN模型的配置作爲本文的附錄提供,轉載如下。我們將使用列出的配置作爲開發我們自己的生成器(g),鑑別器(d)和輔助(q)模型的起點。

讓我們從將生成器模型開發爲深度卷積神經網絡(例如DCGAN)開始。

該模型可以將噪聲向量(z)和控制向量(c)作爲單獨的輸入,並在將它們用作生成圖像的基礎之前將它們連接起來。或者,可以預先將矢量連接起來並提供給模型中的單個輸入層。方法是等價的,在這種情況下我們將使用後者來保持模型簡單。

下面的*define_generator()*函數定義生成器模型,並將輸入向量的大小作爲參數。

完全連接的層採用輸入向量併產生足夠數量的激活,以創建512個7×7特徵映射,從中重新激活激活。然後,它們以1×1步幅通過正常卷積層,然後兩個隨後的上採樣將卷積層轉換爲2×2步幅優先至14×14特徵映射,然後轉換爲所需的1通道28×28特徵映射輸出,其中像素值爲通過tanh激活函數的範圍[-1,-1]。

良好的發生器配置啓發式如下,包括隨機高斯權重初始化,隱藏層中的ReLU激活以及批量歸一化的使用。

# define the standalone generator model
def define_generator(gen_input_size):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image generator input
	in_lat = Input(shape=(gen_input_size,))
	# foundation for 7x7 image
	n_nodes = 512 * 7 * 7
	gen = Dense(n_nodes, kernel_initializer=init)(in_lat)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	gen = Reshape((7, 7, 512))(gen)
	# normal
	gen = Conv2D(128, (4,4), padding='same', kernel_initializer=init)(gen)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	# upsample to 14x14
	gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	# upsample to 28x28
	gen = Conv2DTranspose(1, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
	# tanh output
	out_layer = Activation('tanh')(gen)
	# define model
	model = Model(in_lat, out_layer)
	return model

接下來,我們可以定義鑑別器和輔助模型。

根據普通GAN,鑑別器模型以獨立方式訓練在真實和僞造圖像上。發電機和輔助模型都不直接配合; 相反,它們適合作爲複合模型的一部分。

鑑別器和輔助模型共享相同的輸入和特徵提取層,但它們的輸出層不同。因此,同時定義它們是有意義的。

同樣,有許多方法可以實現這種架構,但是將鑑別器和輔助模型定義爲單獨的模型首先允許我們稍後通過功能API直接將它們組合成更大的GAN模型。

下面的*define_discriminator()*函數定義了鑑別器和輔助模型,並將分類變量的基數(例如數值,例如10)作爲輸入。輸入圖像的形狀也被參數化爲函數參數,並設置爲MNIST圖像大小的默認值。

特徵提取層涉及兩個下采樣層,而不是池化層作爲最佳實踐。此外,遵循DCGAN模型的最佳實踐,我們使用LeakyReLU激活和批量標準化

鑑別器模型(d)具有單個輸出節點,並通過S形激活函數預測輸入圖像的實際概率。該模型被編譯,因爲它將以獨立的方式使用,通過具有最佳實踐學習速率和動量的隨機梯度下降Adam版本來優化二元交叉熵函數

輔助模型(q)對分類變量中的每個值具有一個節點輸出,並使用softmax激活函數。如InfoGAN論文中所使用的那樣,在特徵提取層和輸出層之間添加完全連接的層。該模型未編譯,因爲它不是獨立使用或以獨立方式使用。

# define the standalone discriminator model
def define_discriminator(n_cat, in_shape=(28,28,1)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=in_shape)
	# downsample to 14x14
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	d = LeakyReLU(alpha=0.1)(d)
	# downsample to 7x7
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = LeakyReLU(alpha=0.1)(d)
	d = BatchNormalization()(d)
	# normal
	d = Conv2D(256, (4,4), padding='same', kernel_initializer=init)(d)
	d = LeakyReLU(alpha=0.1)(d)
	d = BatchNormalization()(d)
	# flatten feature maps
	d = Flatten()(d)
	# real/fake output
	out_classifier = Dense(1, activation='sigmoid')(d)
	# define d model
	d_model = Model(in_image, out_classifier)
	# compile d model
	d_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
	# create q model layers
	q = Dense(128)(d)
	q = BatchNormalization()(q)
	q = LeakyReLU(alpha=0.1)(q)
	# q model output
	out_codes = Dense(n_cat, activation='softmax')(q)
	# define q model
	q_model = Model(in_image, out_codes)
	return d_model, q_model

接下來,我們可以定義複合GAN模型。

該模型使用所有子模型,並且是訓練發電機模型權重的基礎。

下面的*define_gan()*函數實現了這個並定義並返回模型,將三個子模型作爲輸入。

如上所述,鑑別器以獨立方式訓練,因此鑑別器的所有權重被設置爲不可訓練(僅在此上下文中)。生成器模型的輸出連接到鑑別器模型的輸入,並連接到輔助模型的輸入。

這將創建一個新的複合模型,該模型將[noise + control]向量作爲輸入,然後通過生成器生成圖像。然後,圖像通過鑑別器模型以產生分類,並通過輔助模型產生控制變量的預測。

該模型有兩個輸出層,需要使用不同的損失函數進行訓練。二進制交叉熵損失用於鑑別器輸出,正如我們在編譯獨立使用的鑑別器時所做的那樣,並且互信息丟失用於輔助模型,在這種情況下,輔助模型可以直接實現爲分類交叉熵並實現期望的結果。

# define the combined discriminator, generator and q network model
def define_gan(g_model, d_model, q_model):
	# make weights in the discriminator (some shared with the q model) as not trainable
	d_model.trainable = False
	# connect g outputs to d inputs
	d_output = d_model(g_model.output)
	# connect g outputs to q inputs
	q_output = q_model(g_model.output)
	# define composite model
	model = Model(g_model.input, [d_output, q_output])
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=opt)
	return model

爲了使GAN模型架構更清晰,我們可以創建模型和複合模型圖。

下面列出了完整的示例。

# create and plot the infogan model for mnist

from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.initializers import RandomNormal
from keras.utils.vis_utils import plot_model

# define the standalone discriminator model

def define_discriminator(n_cat, in_shape=(28,28,1)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=in_shape)
	# downsample to 14x14
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	d = LeakyReLU(alpha=0.1)(d)
	# downsample to 7x7
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = LeakyReLU(alpha=0.1)(d)
	d = BatchNormalization()(d)
	# normal
	d = Conv2D(256, (4,4), padding='same', kernel_initializer=init)(d)
	d = LeakyReLU(alpha=0.1)(d)
	d = BatchNormalization()(d)
	# flatten feature maps
	d = Flatten()(d)
	# real/fake output
	out_classifier = Dense(1, activation='sigmoid')(d)
	# define d model
	d_model = Model(in_image, out_classifier)
	# compile d model
	d_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
	# create q model layers
	q = Dense(128)(d)
	q = BatchNormalization()(q)
	q = LeakyReLU(alpha=0.1)(q)
	# q model output
	out_codes = Dense(n_cat, activation='softmax')(q)
	# define q model
	q_model = Model(in_image, out_codes)
	return d_model, q_model

# define the standalone generator model

def define_generator(gen_input_size):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image generator input
	in_lat = Input(shape=(gen_input_size,))
	# foundation for 7x7 image
	n_nodes = 512 * 7 * 7
	gen = Dense(n_nodes, kernel_initializer=init)(in_lat)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	gen = Reshape((7, 7, 512))(gen)
	# normal
	gen = Conv2D(128, (4,4), padding='same', kernel_initializer=init)(gen)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	# upsample to 14x14
	gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	# upsample to 28x28
	gen = Conv2DTranspose(1, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
	# tanh output
	out_layer = Activation('tanh')(gen)
	# define model
	model = Model(in_lat, out_layer)
	return model

# define the combined discriminator, generator and q network model

def define_gan(g_model, d_model, q_model):
	# make weights in the discriminator (some shared with the q model) as not trainable
	d_model.trainable = False
	# connect g outputs to d inputs
	d_output = d_model(g_model.output)
	# connect g outputs to q inputs
	q_output = q_model(g_model.output)
	# define composite model
	model = Model(g_model.input, [d_output, q_output])
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=opt)
	return model

# number of values for the categorical control code

n_cat = 10

# size of the latent space

latent_dim = 62

# create the discriminator

d_model, q_model = define_discriminator(n_cat)

# create the generator

gen_input_size = latent_dim + n_cat
g_model = define_generator(gen_input_size)

# create the gan

gan_model = define_gan(g_model, d_model, q_model)

# plot the model

plot_model(gan_model, to_file='gan_plot.png', show_shapes=True, show_layer_names=True)

運行該示例將創建所有三個模型,然後創建複合GAN模型並保存模型體系結構的圖。

注意:創建此圖假設已安裝pydot和graphviz庫。如果這是一個問題,您可以註釋掉import語句和對*plot_model()*函數的調用。

該圖顯示了生成器模型的所有細節以及鑑別器和輔助模型的壓縮描述。重要的是,請注意鑑別器輸出的形狀作爲預測圖像是真實還是假的單個節點,以及輔助模型預測分類控制代碼的10個節點。

回想一下,該複合模型將僅用於更新生成器和輔助模型的模型權重,並且鑑別器模型中的所有權重將保持不可約,即僅在更新獨立鑑別器模型時更新。

接下來,我們將爲發電機開發輸入。

每個輸入都是由噪聲和控制代碼組成的矢量。具體地,高斯隨機數的矢量和一個熱編碼的隨機選擇的分類值。

下面的*generate_latent_points()*函數實現了這一點,將潛在空間的大小,分類值的數量以及要生成的樣本數作爲參數作爲輸入。該函數返回輸入連接向量作爲生成器模型的輸入,以及獨立控制代碼。通過複合GAN模型更新發電機和輔助模型時,將需要獨立控制代碼,專門用於計算輔助模型的互信息損失。

# generate points in latent space as input for the generator

def generate_latent_points(latent_dim, n_cat, n_samples):
	# generate points in the latent space
	z_latent = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	z_latent = z_latent.reshape(n_samples, latent_dim)
	# generate categorical codes
	cat_codes = randint(0, n_cat, n_samples)
	# one hot encode
	cat_codes = to_categorical(cat_codes, num_classes=n_cat)
	# concatenate latent points and control codes
	z_input = hstack((z_latent, cat_codes))
	return [z_input, cat_codes]

接下來,我們可以生成真實和虛假的例子。

可以通過爲灰度圖像添加附加維度來加載MNIST數據集,將其轉換爲3D輸入,並將所有像素值縮放到範圍[-1,1]以匹配來自生成器模型的輸出。這是在下面的*load_real_samples()*函數中實現的。

我們可以通過選擇數據集的隨機子集來檢索訓練鑑別器時所需的批量實際樣本。這在下面的*generate_real_samples()*函數中實現,該函數返回圖像和類標籤1,以向鑑別器指示它們是真實圖像。

鑑別器還需要使用來自*generate_latent_points()函數的向量作爲輸入,通過生成器生成批量僞造樣本。下面的generate_fake_samples()*函數實現了這一點,返回生成的圖像以及類標籤0,以向鑑別器指示它們是僞圖像。

# load images

def load_real_samples():
	# load dataset
	(trainX, _), (_, _) = load_data()
	# expand to 3d, e.g. add channels
	X = expand_dims(trainX, axis=-1)
	# convert from ints to floats
	X = X.astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	print(X.shape)
	return X

# select real samples

def generate_real_samples(dataset, n_samples):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# select images and labels
	X = dataset[ix]
	# generate class labels
	y = ones((n_samples, 1))
	return X, y

# use the generator to generate n fake examples, with class labels

def generate_fake_samples(generator, latent_dim, n_cat, n_samples):
	# generate points in latent space and control codes
	z_input, _ = generate_latent_points(latent_dim, n_cat, n_samples)
	# predict outputs
	images = generator.predict(z_input)
	# create class labels
	y = zeros((n_samples, 1))
	return images, y

接下來,我們需要跟蹤生成的圖像的質量。

我們將定期使用生成器生成圖像樣本,並將生成器和複合模型保存到文件中。然後,我們可以在訓練結束時查看生成的圖像,以便選擇最終的生成器模型並加載模型以開始使用它來生成圖像。

下面的*summarize_performance()*函數實現了這一點,首先生成100個圖像,將它們的像素值縮放回範圍[0,1],並將它們保存爲10×10平方的圖像圖。

生成器和複合GAN模型也保存到文件中,具有基於訓練迭代次數的唯一文件名。

# generate samples and save as a plot and save the model

def summarize_performance(step, g_model, gan_model, latent_dim, n_cat, n_samples=100):
	# prepare fake examples
	X, _ = generate_fake_samples(g_model, latent_dim, n_cat, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# plot images
	for i in range(100):
		# define subplot
		pyplot.subplot(10, 10, 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	# save the gan model
	filename3 = 'gan_model_%04d.h5' % (step+1)
	gan_model.save(filename3)
	print('>Saved: %s, %s, and %s' % (filename1, filename2, filename3))

最後,我們可以培訓InfoGAN。

這在下面的*train()*函數中實現,該函數將定義的模型和配置作爲參數並運行訓練過程。

模型訓練100個時期,每批使用64個樣本。MNIST訓練數據集中有60,000個圖像,因此一個時期涉及60,000/64,或937批次或訓練迭代。將其乘以時期數或100,意味着總共將有93,700次訓練迭代次數。

每次訓練迭代包括首先用半批真實樣本和半批假樣本更新鑑別器,以形成一批重量更新,或每次迭代64次。接下來,基於批量噪聲和控制代碼輸入更新複合GAN模型。每次訓練迭代都會報告真實和假圖像上的鑑別器的丟失以及發生器和輔助模型的丟失。

# train the generator and discriminator

def train(g_model, d_model, gan_model, dataset, latent_dim, n_cat, n_epochs=100, n_batch=64):
	# calculate the number of batches per training epoch
	bat_per_epo = int(dataset.shape[0] / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# calculate the size of half a batch of samples
	half_batch = int(n_batch / 2)
	# manually enumerate epochs
	for i in range(n_steps):
		# get randomly selected 'real' samples
		X_real, y_real = generate_real_samples(dataset, half_batch)
		# update discriminator and q model weights
		d_loss1 = d_model.train_on_batch(X_real, y_real)
		# generate 'fake' examples
		X_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_cat, half_batch)
		# update discriminator model weights
		d_loss2 = d_model.train_on_batch(X_fake, y_fake)
		# prepare points in latent space as input for the generator
		z_input, cat_codes = generate_latent_points(latent_dim, n_cat, n_batch)
		# create inverted labels for the fake samples
		y_gan = ones((n_batch, 1))
		# update the g via the d and q error
		_,g_1,g_2 = gan_model.train_on_batch(z_input, [y_gan, cat_codes])
		# summarize loss on this batch
		print('>%d, d[%.3f,%.3f], g[%.3f] q[%.3f]' % (i+1, d_loss1, d_loss2, g_1, g_2))
		# evaluate the model performance every 'epoch'
		if (i+1) % (bat_per_epo * 10) == 0:
			summarize_performance(i, g_model, gan_model, latent_dim, n_cat)

然後我們可以配置和創建模型,然後運行培訓過程。

我們將使用10個值作爲單個分類變量來匹配MNIST數據集中的10個已知類。我們將使用64維的潛在空間來匹配InfoGAN論文,這意味着,在這種情況下,生成器模型的每個輸入向量將是64(隨機高斯變量)+ 10(一個熱編碼控制變量)或72個元素長度。

# number of values for the categorical control code

n_cat = 10

# size of the latent space

latent_dim = 62

# create the discriminator

d_model, q_model = define_discriminator(n_cat)

# create the generator

gen_input_size = latent_dim + n_cat
g_model = define_generator(gen_input_size)

# create the gan

gan_model = define_gan(g_model, d_model, q_model)

# load image data

dataset = load_real_samples()

# train model

train(g_model, d_model, gan_model, dataset, latent_dim, n_cat)

將這一點結合在一起,下面列出了使用單個分類控制變量在MNIST數據集上訓練InfoGAN模型的完整示例。

# example of training an infogan on mnist

from numpy import zeros
from numpy import ones
from numpy import expand_dims
from numpy import hstack
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.utils import to_categorical
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.layers import Activation
from matplotlib import pyplot

# define the standalone discriminator model

def define_discriminator(n_cat, in_shape=(28,28,1)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=in_shape)
	# downsample to 14x14
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	d = LeakyReLU(alpha=0.1)(d)
	# downsample to 7x7
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = LeakyReLU(alpha=0.1)(d)
	d = BatchNormalization()(d)
	# normal
	d = Conv2D(256, (4,4), padding='same', kernel_initializer=init)(d)
	d = LeakyReLU(alpha=0.1)(d)
	d = BatchNormalization()(d)
	# flatten feature maps
	d = Flatten()(d)
	# real/fake output
	out_classifier = Dense(1, activation='sigmoid')(d)
	# define d model
	d_model = Model(in_image, out_classifier)
	# compile d model
	d_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
	# create q model layers
	q = Dense(128)(d)
	q = BatchNormalization()(q)
	q = LeakyReLU(alpha=0.1)(q)
	# q model output
	out_codes = Dense(n_cat, activation='softmax')(q)
	# define q model
	q_model = Model(in_image, out_codes)
	return d_model, q_model

# define the standalone generator model

def define_generator(gen_input_size):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image generator input
	in_lat = Input(shape=(gen_input_size,))
	# foundation for 7x7 image
	n_nodes = 512 * 7 * 7
	gen = Dense(n_nodes, kernel_initializer=init)(in_lat)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	gen = Reshape((7, 7, 512))(gen)
	# normal
	gen = Conv2D(128, (4,4), padding='same', kernel_initializer=init)(gen)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	# upsample to 14x14
	gen = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
	gen = Activation('relu')(gen)
	gen = BatchNormalization()(gen)
	# upsample to 28x28
	gen = Conv2DTranspose(1, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(gen)
	# tanh output
	out_layer = Activation('tanh')(gen)
	# define model
	model = Model(in_lat, out_layer)
	return model

# define the combined discriminator, generator and q network model

def define_gan(g_model, d_model, q_model):
	# make weights in the discriminator (some shared with the q model) as not trainable
	d_model.trainable = False
	# connect g outputs to d inputs
	d_output = d_model(g_model.output)
	# connect g outputs to q inputs
	q_output = q_model(g_model.output)
	# define composite model
	model = Model(g_model.input, [d_output, q_output])
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=opt)
	return model

# load images

def load_real_samples():
	# load dataset
	(trainX, _), (_, _) = load_data()
	# expand to 3d, e.g. add channels
	X = expand_dims(trainX, axis=-1)
	# convert from ints to floats
	X = X.astype('float32')
	# scale from [0,255] to [-1,1]
	X = (X - 127.5) / 127.5
	print(X.shape)
	return X

# select real samples

def generate_real_samples(dataset, n_samples):
	# choose random instances
	ix = randint(0, dataset.shape[0], n_samples)
	# select images and labels
	X = dataset[ix]
	# generate class labels
	y = ones((n_samples, 1))
	return X, y

# generate points in latent space as input for the generator

def generate_latent_points(latent_dim, n_cat, n_samples):
	# generate points in the latent space
	z_latent = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	z_latent = z_latent.reshape(n_samples, latent_dim)
	# generate categorical codes
	cat_codes = randint(0, n_cat, n_samples)
	# one hot encode
	cat_codes = to_categorical(cat_codes, num_classes=n_cat)
	# concatenate latent points and control codes
	z_input = hstack((z_latent, cat_codes))
	return [z_input, cat_codes]

# use the generator to generate n fake examples, with class labels

def generate_fake_samples(generator, latent_dim, n_cat, n_samples):
	# generate points in latent space and control codes
	z_input, _ = generate_latent_points(latent_dim, n_cat, n_samples)
	# predict outputs
	images = generator.predict(z_input)
	# create class labels
	y = zeros((n_samples, 1))
	return images, y

# generate samples and save as a plot and save the model

def summarize_performance(step, g_model, gan_model, latent_dim, n_cat, n_samples=100):
	# prepare fake examples
	X, _ = generate_fake_samples(g_model, latent_dim, n_cat, n_samples)
	# scale from [-1,1] to [0,1]
	X = (X + 1) / 2.0
	# plot images
	for i in range(100):
		# define subplot
		pyplot.subplot(10, 10, 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
	# save plot to file
	filename1 = 'generated_plot_%04d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	# save the gan model
	filename3 = 'gan_model_%04d.h5' % (step+1)
	gan_model.save(filename3)
	print('>Saved: %s, %s, and %s' % (filename1, filename2, filename3))

# train the generator and discriminator

def train(g_model, d_model, gan_model, dataset, latent_dim, n_cat, n_epochs=100, n_batch=64):
	# calculate the number of batches per training epoch
	bat_per_epo = int(dataset.shape[0] / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# calculate the size of half a batch of samples
	half_batch = int(n_batch / 2)
	# manually enumerate epochs
	for i in range(n_steps):
		# get randomly selected 'real' samples
		X_real, y_real = generate_real_samples(dataset, half_batch)
		# update discriminator and q model weights
		d_loss1 = d_model.train_on_batch(X_real, y_real)
		# generate 'fake' examples
		X_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_cat, half_batch)
		# update discriminator model weights
		d_loss2 = d_model.train_on_batch(X_fake, y_fake)
		# prepare points in latent space as input for the generator
		z_input, cat_codes = generate_latent_points(latent_dim, n_cat, n_batch)
		# create inverted labels for the fake samples
		y_gan = ones((n_batch, 1))
		# update the g via the d and q error
		_,g_1,g_2 = gan_model.train_on_batch(z_input, [y_gan, cat_codes])
		# summarize loss on this batch
		print('>%d, d[%.3f,%.3f], g[%.3f] q[%.3f]' % (i+1, d_loss1, d_loss2, g_1, g_2))
		# evaluate the model performance every 'epoch'
		if (i+1) % (bat_per_epo * 10) == 0:
			summarize_performance(i, g_model, gan_model, latent_dim, n_cat)

# number of values for the categorical control code

n_cat = 10

# size of the latent space

latent_dim = 62

# create the discriminator

d_model, q_model = define_discriminator(n_cat)

# create the generator

gen_input_size = latent_dim + n_cat
g_model = define_generator(gen_input_size)

# create the gan

gan_model = define_gan(g_model, d_model, q_model)

# load image data

dataset = load_real_samples()

# train model

train(g_model, d_model, gan_model, dataset, latent_dim, n_cat)

運行該示例可能需要一些時間,建議使用GPU硬件,但不是必需的。

注意:鑑於訓練算法的隨機性,您的結果可能會有所不同。嘗試運行幾次示例。

每次訓練迭代都會報告模型中的損失。如果鑑別器的損失保持在0.0或長時間變爲0.0,這可能是訓練失敗的跡象,您可能想要重新開始訓練過程。鑑別器損失可能從0.0開始,但可能會增加,就像在這種特定情況下一樣。

輔助模型的損失可能會歸零,因爲它可以完美地預測分類變量。發電機和鑑別器模型的損失最終可能會在1.0左右徘徊,以展示穩定的訓練過程或兩種模型訓練之間的平衡。

1, d[0.924,0.758], g[0.448] q[2.909]
2, d[0.000,2.699], g[0.547] q[2.704]
3, d[0.000,1.557], g[1.175] q[2.820]
4, d[0.000,0.941], g[1.466] q[2.813]
5, d[0.000,1.013], g[1.908] q[2.715]
...
93696, d[0.814,1.212], g[1.283] q[0.000]
93697, d[1.063,0.920], g[1.132] q[0.000]
93698, d[0.999,1.188], g[1.128] q[0.000]
93699, d[0.935,0.985], g[1.229] q[0.000]
93700, d[0.968,1.016], g[1.200] q[0.001]
Saved: generated_plot_93700.png, model_93700.h5, and gan_model_93700.h5

每10個時期或每9,370次訓練迭代中保存圖和模型。

回顧這些圖表應該顯示早期時代的低質量圖像以及後期時代的改進和穩定質量的圖像。

例如,在前10個時期之後保存的圖像的圖表低於顯示低質量的生成圖像。

更多時代並不意味着更好的質量,這意味着最佳質量的圖像可能不是來自訓練結束時保存的最終模型的圖像。

查看圖表並選擇具有最佳圖像質量的最終模型。在這種情況下,我們將使用在100個紀元或93,700次訓練迭代後保存的模型。

如何使用訓練有素的InfoGAN模型使用控制代碼

現在我們已經培訓了InfoGAN模型,我們可以探索如何使用它。

首先,我們可以加載模型並使用它來生成隨機圖像,就像我們在訓練期間所做的那樣。

下面列出了完整的示例。

更改模型文件名以匹配在訓練期間生成最佳圖像的模型文件名。

# example of loading the generator model and generating images

from math import sqrt
from numpy import hstack
from numpy.random import randn
from numpy.random import randint
from keras.models import load_model
from keras.utils import to_categorical
from matplotlib import pyplot

# generate points in latent space as input for the generator

def generate_latent_points(latent_dim, n_cat, n_samples):
	# generate points in the latent space
	z_latent = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	z_latent = z_latent.reshape(n_samples, latent_dim)
	# generate categorical codes
	cat_codes = randint(0, n_cat, n_samples)
	# one hot encode
	cat_codes = to_categorical(cat_codes, num_classes=n_cat)
	# concatenate latent points and control codes
	z_input = hstack((z_latent, cat_codes))
	return [z_input, cat_codes]

# create a plot of generated images

def create_plot(examples, n_examples):
	# plot images
	for i in range(n_examples):
		# define subplot
		pyplot.subplot(sqrt(n_examples), sqrt(n_examples), 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(examples[i, :, :, 0], cmap='gray_r')
	pyplot.show()

# load model

model = load_model('model_93700.h5')

# number of values for the categorical control code

n_cat = 10

# size of the latent space

latent_dim = 62

# number of examples to generate

n_samples = 100

# generate points in latent space and control codes

z_input, _ = generate_latent_points(latent_dim, n_cat, n_samples)

# predict outputs

X = model.predict(z_input)

# scale from [-1,1] to [0,1]

X = (X + 1) / 2.0

# plot the result

create_plot(X, n_samples)

運行該示例將加載已保存的生成器模型並使用它生成100個隨機圖像並將圖像繪製在10×10網格上。

接下來,我們可以更新示例以測試控制變量給我們的控制程度。

我們可以更新*generate_latent_points()*函數,以獲取[0,9]中分類值的參數,對其進行編碼,並將其與噪聲向量一起用作輸入。

# generate points in latent space as input for the generator

def generate_latent_points(latent_dim, n_cat, n_samples, digit):
	# generate points in the latent space
	z_latent = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	z_latent = z_latent.reshape(n_samples, latent_dim)
	# define categorical codes
	cat_codes = asarray([digit for _ in range(n_samples)])
	# one hot encode
	cat_codes = to_categorical(cat_codes, num_classes=n_cat)
	# concatenate latent points and control codes
	z_input = hstack((z_latent, cat_codes))
	return [z_input, cat_codes]

我們可以通過生成具有分類值1的25個圖像的網格來測試這一點。

下面列出了完整的示例。

# example of testing different values of the categorical control variable

from math import sqrt
from numpy import asarray
from numpy import hstack
from numpy.random import randn
from numpy.random import randint
from keras.models import load_model
from keras.utils import to_categorical
from matplotlib import pyplot

# generate points in latent space as input for the generator

def generate_latent_points(latent_dim, n_cat, n_samples, digit):
	# generate points in the latent space
	z_latent = randn(latent_dim * n_samples)
	# reshape into a batch of inputs for the network
	z_latent = z_latent.reshape(n_samples, latent_dim)
	# define categorical codes
	cat_codes = asarray([digit for _ in range(n_samples)])
	# one hot encode
	cat_codes = to_categorical(cat_codes, num_classes=n_cat)
	# concatenate latent points and control codes
	z_input = hstack((z_latent, cat_codes))
	return [z_input, cat_codes]

# create and save a plot of generated images

def save_plot(examples, n_examples):
	# plot images
	for i in range(n_examples):
		# define subplot
		pyplot.subplot(sqrt(n_examples), sqrt(n_examples), 1 + i)
		# turn off axis
		pyplot.axis('off')
		# plot raw pixel data
		pyplot.imshow(examples[i, :, :, 0], cmap='gray_r')
	pyplot.show()

# load model

model = load_model('model_93700.h5')

# number of categorical control codes

n_cat = 10

# size of the latent space

latent_dim = 62

# number of examples to generate

n_samples = 25

# define digit

digit = 1

# generate points in latent space and control codes

z_input, _ = generate_latent_points(latent_dim, n_cat, n_samples, digit)

# predict outputs

X = model.predict(z_input)

# scale from [-1,1] to [0,1]

X = (X + 1) / 2.0

# plot the result

save_plot(X, n_samples)

結果是生成25個生成圖像的網格,其中分類代碼設置爲值1。

注意:鑑於訓練算法的隨機性,您的結果可能會有所不同。

期望控制代碼的值影響生成的圖像; 特別是,他們預計會影響數字類型。但是,預計它們不會被訂購,例如,控制代碼1,2和3來創建這些數字。

然而,在這種情況下,值爲1的控制代碼導致生成的圖像看起來像1。

嘗試使用不同的數字並查看值對圖像的確切控制。

例如,在這種情況下將值設置爲5(數字= 5)會導致生成的圖像看起來像數字“ 8 ”。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章