GAN代碼解析(tensorflow實現)_手寫數字圖片生成
基於py3.0支持中文名方法, 如果報錯請把中文方法名,改爲英文的
#coding:utf-8
# MNIST數據集
# MNIST數據集的官網是Yann LeCun’s website。在這裏,我們提供了一份python源代碼用於自動下載和安裝這個數據集。
# 你可以下載這份代碼,然後用下面的代碼導入到你的項目裏面,也可以直接複製粘貼到你的代碼文件裏面。
# import input_data
# mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# 下載下來的數據集被分成兩部分:60000行的訓練數據集(mnist.train)和10000行的測試數據集(mnist.test)。
#
# 每一張圖片包含28像素X28像素。我們可以用一個數字數組來表示這張圖片:
#
# 我們把這個數組展開成一個向量,長度是 28x28 = 784
#
# 因此,在MNIST訓練數據集中,mnist.train.images 是一個形狀爲 [60000, 784] 的張量,
# 第一個維度數字用來索引圖片,第二個維度數字用來索引每張圖片中的像素點。在此張量裏的每一個元素,都表示某張圖片裏的某個像素的強度值,值介於0和1之間。
#
# 相對應的MNIST數據集的標籤是介於0到9的數字,用來描述給定圖片裏表示的數字.因此, mnist.train.labels 是一個 [60000, 10] 的數字矩陣。
# Dropout中隱層節點的忽略比例主要作用在隱層節點,是按照一定比例,隨機地使部分隱層節點失效,並且該比例與最後通過模型平均來求得最後的預測值也有一定的關係。
# DAE中加噪比例作用於輸入層,是按照一定比例,對每個網絡的輸入數據加入噪聲,使得自動編碼器通過學習獲得真正的沒有被噪聲污染過的輸入。這種加入噪聲的思想,並不需要進行模型平均。
# tf.slice()介紹
# 函數:tf.slice(inputs, begin, size, name)
# 作用:從列表、數組、張量等對象中抽取一部分數據
# begin和size是兩個多維列表,他們共同決定了要抽取的數據的開始和結束位置
# begin表示從inputs的哪幾個維度上的哪個元素開始抽取
# size表示在inputs的各個維度上抽取的元素個數
# 若begin[]或size[]中出現-1,表示抽取對應維度上的所有元素
# import tensorflow as tf
# import numpy as np
# x=[[1,2,3],[4,5,6]]
# with tf.Session() as sess:
# begin = [0,1] # 從x[0,1],即元素2開始抽取
# size = [2,1] # 從x[0,1]開始,對x的第一個維度(行)抽取2個元素,在對x的第二個維度(列)抽取1個元素
# print sess.run(tf.slice(x,begin,size)) # 輸出[[2 5]]
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #導入圖片數據
import numpy as np
from skimage.io import imsave #讀圖片
import os
import shutil #圖片處理
import sys
img_height = 28
img_width = 28
img_size = img_height * img_width #圖片像素28X28 ,做全鏈接拉伸784
to_train = True #訓練開關
to_restore = False #保存模型開關
output_path = "./output" #保存模型路徑
# 總迭代次數500
max_epoch = 500
#隱層神經元的個數
h1_size = 150 #隱層神經元第一層神經元的個數
h2_size = 300 #隱層神經元第2層神經元的個數
z_size = 100 #輸入的噪音點(輸入也爲100)
batch_size = 256 #batch_size 一次256張圖片。 判別模型有512張(真假各一半)
# generate (model 1) 這裏用的全連接
def 生成模型(z_prior):#build_generator #初始化W,b參數 ,剛開始輸入層是100,hide1 150,hide2 是300
w1 = tf.Variable(tf.truncated_normal(shape=[z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32) #truncated_normal隨機生成函數產生正太分佈的W,這是一個截斷的產生正太分佈的函數,就是說產生正太分佈的值如果與均值的差值大於兩倍的標準差,那就重新生成。
b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)
w2 = tf.Variable(tf.truncated_normal(shape=[h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)
w3 = tf.Variable(tf.truncated_normal(shape=[h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)
b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32) #b3的大小相當於圖片拉直後的像素點個數784
h3 = tf.matmul(h2, w3) + b3
x_generate = tf.nn.tanh(h3) #生成784維的數
g_params = [w1, b1, w2, b2, w3, b3]
return x_generate, g_params
# discriminator (model 2)
# x_data是真是值
# keep_prob 隱層節點的忽略比例 ,dropout 比例
def 判別模型(x_data, x_generated, keep_prob):#build_discriminator
# tf.concat
x_in = tf.concat([x_data, x_generated], 0)#兩倍於生成模型數量的圖片,因爲要參雜真是圖片
w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32) #784X300
b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)#W1 784X300
w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)#300X150
b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)
w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)#150
b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
h3 = tf.matmul(h2, w3) + b3#是輸出一個數判別生成真是圖片的概率
y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))#切片,代表的是真實的數據y_data ,[batch_size, -1]256行,-1代表所有列,生成的256
y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))#原圖形和生成圖形一一對應,[-1, -1]取剩下的所有行所有列
d_params = [w1, b1, w2, b2, w3, b3]
return y_data, y_generated, d_params
# 保存圖片進度及保存圖片算法
#show_result(x_gen_val, "output_random/random_sample{0}.jpg".format(i))
# grid_pad=5 沒有填充的地方用5去填
def 展示結果保存(batch_res, fname, grid_size=(8, 8), grid_pad=5):#show_result
#數字轉化爲圖片
batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5 #除以2 reshape 784維的常量 變爲28x28,分別作爲圖片像素的寬高,保存下來.如下,做圖像還原時要重新加上均值0.5
img_h, img_w = batch_res.shape[1], batch_res.shape[2]
grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
for i, res in enumerate(batch_res):
if i >= grid_size[0] * grid_size[1]:
break
img = (res) * 255
img = img.astype(np.uint8)
row = (i // grid_size[0]) * (img_h + grid_pad)
col = (i % grid_size[1]) * (img_w + grid_pad)
img_grid[row:row + img_h, col:col + img_w] = img
imsave(fname, img_grid) #保存圖
def 開始訓練():
# load data(mnist手寫數據集)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data") #真實值256X784 ,一次傳的數據
z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior") #輸入
keep_prob = tf.placeholder(tf.float32, name="keep_prob")#dropout 比例0.7
global_step = tf.Variable(0, name="global_step", trainable=False)#總共迭代多少步?反向更新多少次
# 創建生成模型
x_generated, g_params = 生成模型(z_prior)
# 創建判別模型
y_data, y_generated, d_params = 判別模型(x_data, x_generated, keep_prob)
# 損失函數的設置
d_loss = - (tf.log(y_data) + tf.log(1 - y_generated)) #整個數據的交叉熵
g_loss = - tf.log(y_generated)#生成器的損失函數,計算的就是生產數據的交叉熵
optimizer = tf.train.AdamOptimizer(0.0001)
# 兩個模型的優化函數
d_trainer = optimizer.minimize(d_loss, var_list=d_params)
g_trainer = optimizer.minimize(g_loss, var_list=g_params)
init = tf.initialize_all_variables()
# init = tf.global_variables_initializer()
saver = tf.train.Saver()
# 啓動默認圖
sess = tf.Session()
# 初始化圖
sess.run(init)
#tensorflow模型持久化
if to_restore:
chkpt_fname = tf.train.latest_checkpoint(output_path)
print(chkpt_fname)
saver.restore(sess, chkpt_fname)
else:
if os.path.exists(output_path):
shutil.rmtree(output_path)#如果存在則刪除
os.mkdir(output_path)#如果不存在則創建
z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)#接着從0-1均勻分佈中抽取了z(至於爲什麼用這個分佈,可以去查看一個概率論,幾乎所有重要的概率分佈都可以從均勻分佈Uniform(0,1)中生成出來)
steps = 60000 / batch_size #訓練集中圖片的數量60000,
for i in range(sess.run(global_step), max_epoch):
for j in np.arange(steps):
# for j in range(steps):
print("epoch:%s, iter:%s" % (i, j))
# 每一步迭代,我們都會加載256個訓練樣本,然後執行一次train_step
x_value, _ = mnist.train.next_batch(batch_size)
x_value = 2 * x_value.astype(np.float32) - 1#python是從0開始的
z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
# 執行生成
sess.run(d_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
# 執行判別
if j % 1 == 0: #每個bitch_size打印一次
sess.run(g_trainer,
feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
展示結果保存(x_gen_val, "./output_sample/sample{0}.jpg".format(i))
#以下三句可以省略
z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
展示結果保存(x_gen_val, "./output_random/random_sample{0}.jpg".format(i))
sess.run(tf.assign(global_step, i + 1)) #tf.assign(A, new_number): 這個函數的功能主要是把A的值變爲new_number,賦值作用
saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)
# def test():
# z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
# x_generated, _ = build_generator(z_prior)
# chkpt_fname = tf.train.latest_checkpoint(output_path)
#
# init = tf.initialize_all_variables()
# sess = tf.Session()
# saver = tf.train.Saver()
# sess.run(init)
# saver.restore(sess, chkpt_fname)
# z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
# x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
# show_result(x_gen_val, "output/test_result.jpg")
if __name__ == '__main__':
# if to_train:
# train()
# else:
# test()
開始訓練()
random_sample314.jpg