該項目源碼位於:
https://github.com/jazzsaxmafia/dcgan_tensorflow/tree/master/mnist 。項目共有四個Python文件,分別是util.py、load.py、model.py和train.py。
util.py文件把項目中用到的工具函數整理到一起,包括one-hot編碼、圖像剪裁和圖像拼接。
# !/usr/bin/python
# -*- coding: utf-8 -*-
import cv2
import scipy.misc
import ipdb
import numpy as np
def OneHot(X, n=None, negative_class=0.):
    """Return a one-hot encoding of the class indices in ``X``.

    Args:
        X: array-like of class indices of any shape; it is flattened first.
            Values are cast to ``int`` so float-typed labels (e.g. ``3.0``)
            can be used as indices.
        n: number of classes (columns). When ``None`` it is inferred as
            ``max(X) + 1`` (0 for empty input). Pass it explicitly whenever a
            fixed width is required, because a batch that happens to miss the
            largest class would otherwise produce a narrower matrix.
        negative_class: fill value for every "off" position.

    Returns:
        Float array of shape ``(X.size, n)`` holding ``1.`` at
        ``[i, X[i]]`` and ``negative_class`` everywhere else.
    """
    # Flatten to a 1-D vector of integer indices.
    labels = np.asarray(X).flatten().astype(int)
    if n is None:
        # Guard the empty case: np.max raises on a zero-size array.
        n = int(labels.max()) + 1 if labels.size else 0
    # Start from a (len, n) matrix filled with the "off" value ...
    encoded = np.full((len(labels), n), negative_class, dtype=float)
    # ... then switch on one column per row.
    encoded[np.arange(len(labels)), labels] = 1.
    return encoded
# 例:生成如下的矩陣
# [[ 0. 0. 0. ..., 0. 0. 0.]
# [ 0. 0. 0. ..., 0. 0. 1.]
# [ 1. 0. 0. ..., 0. 0. 0.]
# ...,
# [ 0. 0. 0. ..., 0. 0. 1.]
# [ 0. 0. 0. ..., 0. 0. 0.]
# [ 0. 0. 0. ..., 0. 0. 0.]]
def crop_resize(image_path, resize_shape=(64,64)):
    """Load an image, scale it, centre-crop it and map pixels to [-1, 1].

    The shorter side of the image is scaled to the target size and the
    longer side is cropped symmetrically around the centre.

    Args:
        image_path: path handed to ``cv2.imread``.
        resize_shape: target size. The index arithmetic below mixes
            ``resize_shape[0]`` and ``resize_shape[1]``, so the result is
            only well defined for a square target.

    Returns:
        Float array holding the cropped image with values ``pixel/127.5 - 1``.

    Raises:
        IOError: if ``cv2.imread`` could not read the file (it returns
            ``None`` in that case instead of raising).
    """
    image = cv2.imread(image_path)
    if image is None:
        raise IOError("could not read image: %s" % image_path)

    height, width = image.shape[:2]
    target_w, target_h = resize_shape[0], resize_shape[1]

    if width == height:
        # Already square: a plain resize is enough.
        resized_image = cv2.resize(image, resize_shape)
    elif width > height:
        # Landscape: scale by the height ratio, then crop columns.
        scaled_w = int(width * float(target_w) / height)
        resized_image = cv2.resize(image, (scaled_w, target_h))
        offset = int((resized_image.shape[1] - target_w) / 2)
        resized_image = resized_image[:, offset:offset + target_h]
    else:
        # Portrait: scale by the width ratio, then crop rows.
        scaled_h = int(height * float(target_h) / width)
        resized_image = cv2.resize(image, (target_w, scaled_h))
        offset = int((resized_image.shape[0] - target_h) / 2)
        resized_image = resized_image[offset:offset + target_w, :]

    # Map 8-bit pixel values from [0, 255] to [-1, 1].
    return resized_image / 127.5 - 1
def save_visualization(X, grid_shape, save_path='./vis/sample.jpg'):
    """Tile a batch of images into one grid image and write it to disk.

    Args:
        X: array of images indexed as ``X[n]``; ``X.shape[1]`` and
            ``X.shape[2]`` are taken as the tile height and width. Each tile
            may be ``(h, w)``, ``(h, w, 1)`` or ``(h, w, 3)``.
        grid_shape: ``(nh, nw)`` - number of tile rows and columns. Images
            beyond ``nh * nw`` are ignored; unused cells stay black.
        save_path: output file name passed to ``scipy.misc.imsave``.
    """
    # The grid size is unpacked here rather than in the signature: tuple
    # parameters are a syntax error on Python 3.
    nh, nw = grid_shape
    h, w = X.shape[1], X.shape[2]
    img = np.zeros((h * nh, w * nw, 3))

    for n, x in enumerate(X[:nh * nw]):
        # Row-major placement; divmod keeps the indices integral on both
        # Python 2 and Python 3.
        j, i = divmod(n, nw)
        if x.ndim == 2:
            # Add a channel axis so a plain (h, w) tile broadcasts to 3 channels.
            x = x[:, :, None]
        img[j*h:j*h+h, i*w:i*w+w, :] = x

    # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; an older SciPy
    # (with Pillow) is needed, or this call has to be swapped for another writer.
    scipy.misc.imsave(save_path, img)
load.py文件的作用是將MNIST數據導入並進行數據集的劃分
# !/usr/bin/python
# -*- coding: utf-8 -*-
import sys
sys.path.append('..')
import numpy as np
import os
from time import time
from collections import Counter
import random
from matplotlib import pyplot as plt
# Directory that mnist() joins with the four MNIST idx file names
# (train-images-idx3-ubyte, train-labels-idx1-ubyte, t10k-images-idx3-ubyte,
# t10k-labels-idx1-ubyte). This is a machine-specific path; change it locally.
data_dir = '/Users/Chris/MNIST_data'
def _read_idx(path, header_bytes):
    """Read a raw idx file as uint8 and drop its ``header_bytes``-byte header."""
    # Passing the path lets numpy open the file in binary mode and close it
    # again, instead of leaking a text-mode handle.
    return np.fromfile(path, dtype=np.uint8)[header_bytes:]


def mnist(data_path=None):
    """Load the MNIST training and test sets from uncompressed idx files.

    Args:
        data_path: directory containing the four idx files. Defaults to the
            module-level ``data_dir``.

    Returns:
        ``(trX, teX, trY, teY)``: images as float arrays of shape
        ``(num_samples, 784)`` with raw pixel values, and labels as uint8
        vectors.
    """
    root = data_dir if data_path is None else data_path

    # Image files are sliced after a 16-byte header, label files after 8 bytes.
    # The sample count is inferred from the file size instead of hard-coded.
    trX = _read_idx(os.path.join(root, 'train-images-idx3-ubyte'), 16)
    trX = trX.reshape((-1, 28*28)).astype(float)
    trY = _read_idx(os.path.join(root, 'train-labels-idx1-ubyte'), 8)

    teX = _read_idx(os.path.join(root, 't10k-images-idx3-ubyte'), 16)
    teX = teX.reshape((-1, 28*28)).astype(float)
    teY = _read_idx(os.path.join(root, 't10k-labels-idx1-ubyte'), 8)

    return trX, teX, trY, teY
def mnist_with_valid_set(n_train=50000):
    """Load MNIST and split the shuffled training data into train/validation.

    Args:
        n_train: number of shuffled training samples kept for training; the
            remainder becomes the validation set.

    Returns:
        ``(trX, vaX, teX, trY, vaY, teY)`` - training, validation and test
        images followed by the matching labels.
    """
    trX, teX, trY, teY = mnist()

    # Shuffle images and labels with one shared permutation. A permutation
    # array is used because a Python 3 ``range`` cannot be shuffled in place.
    order = np.random.permutation(len(trX))
    trX = trX[order]
    trY = trY[order]

    # The first n_train samples form the new training set, the rest the
    # validation set.
    vaX, vaY = trX[n_train:], trY[n_train:]
    trX, trY = trX[:n_train], trY[:n_train]

    return trX, vaX, teX, trY, vaY, teY
model.py是模型的網絡結構,該文件還包括模型會用到的激活函數和正則化方法等
# !/usr/bin/python
#-*- coding: utf-8 -*-
import tensorflow as tf
import ipdb
def batchnormalize(X, eps=1e-8, g=None, b=None):
    """Normalize ``X`` with statistics of the current batch.

    A rank-4 tensor is reduced over axes 0, 1 and 2, a rank-2 tensor over
    axis 0, so statistics are kept per entry of the last axis.

    Args:
        X: rank-2 or rank-4 tensor.
        eps: constant added to the variance before the square root.
        g: optional scale; applied only when ``b`` is given as well.
        b: optional shift; applied only when ``g`` is given as well.

    Returns:
        The normalized (and optionally scaled and shifted) tensor.

    Raises:
        NotImplementedError: if the static rank of ``X`` is not 2 or 4.
    """
    rank = X.get_shape().ndims
    if rank not in (2, 4):
        raise NotImplementedError

    reduce_axes = [0, 1, 2] if rank == 4 else 0
    mean = tf.reduce_mean(X, reduce_axes)
    # Mean squared deviation, i.e. the (biased) batch variance.
    variance = tf.reduce_mean(tf.square(X - mean), reduce_axes)
    normalized = (X - mean) / tf.sqrt(variance + eps)

    if g is None or b is None:
        return normalized

    # Reshape scale and shift so they broadcast along the last axis.
    param_shape = [1, 1, 1, -1] if rank == 4 else [1, -1]
    scale = tf.reshape(g, param_shape)
    shift = tf.reshape(b, param_shape)
    return normalized * scale + shift
def lrelu(X, leak=0.2):
    """Leaky ReLU activation.

    Shaped like ReLU, except that the negative half-axis keeps a slope of
    ``leak`` instead of being flat. It is written as a blend of ``X`` and
    ``|X|``: for ``X >= 0`` the two coefficients sum to 1 and the result is
    ``X``; for ``X < 0`` their difference gives ``leak * X``.
    """
    linear_coeff = 0.5 * (1 + leak)
    abs_coeff = 0.5 * (1 - leak)
    return linear_coeff * X + abs_coeff * tf.abs(X)
def bce(o, t):
    """Element-wise binary cross-entropy between outputs ``o`` and targets ``t``."""
    # Clip the outputs into [1e-7, 1 - 1e-7] so neither log sees 0.
    clipped = tf.clip_by_value(o, 1e-7, 1. - 1e-7)
    positive_term = t * tf.log(clipped)
    negative_term = (1. - t) * tf.log(1. - clipped)
    return -(positive_term + negative_term)
class DCGAN():
    """Label-conditioned DCGAN for 28x28 images (TensorFlow 1.x graph API).

    The one-hot label ``Y`` is concatenated to the input of every layer of
    both networks - as extra feature-map channels for convolutional layers
    and as extra columns for fully connected layers - so that generator and
    discriminator both see the requested class at each stage.
    """

    def __init__(
            self,
            batch_size=100,
            image_shape=[28,28,1],
            dim_z=100,      # dimensionality of the noise vector z
            dim_y=10,       # number of classes (digits 0-9)
            dim_W1=1024,
            dim_W2=128,
            dim_W3=64,
            dim_channel=1,  # channels of the (grayscale) image
            ):
        self.batch_size = batch_size
        # Copy so the shared default list is never stored or mutated.
        self.image_shape = list(image_shape)
        self.dim_z = dim_z
        self.dim_y = dim_y
        self.dim_W1 = dim_W1
        self.dim_W2 = dim_W2
        self.dim_W3 = dim_W3
        self.dim_channel = dim_channel

        # Variable names carry a 'gen' / 'discrim' prefix so the training
        # script can select each network's parameters by name.
        self.gen_W1 = tf.Variable(tf.random_normal([dim_z+dim_y, dim_W1], stddev=0.02), name='gen_W1')
        self.gen_W2 = tf.Variable(tf.random_normal([dim_W1+dim_y, dim_W2*7*7], stddev=0.02), name='gen_W2')
        # conv2d_transpose filters are [height, width, out_channels, in_channels].
        self.gen_W3 = tf.Variable(tf.random_normal([5,5,dim_W3,dim_W2+dim_y], stddev=0.02), name='gen_W3')
        self.gen_W4 = tf.Variable(tf.random_normal([5,5,dim_channel,dim_W3+dim_y], stddev=0.02), name='gen_W4')

        # dim_W3 5x5 kernels over dim_channel+dim_y input channels.
        self.discrim_W1 = tf.Variable(tf.random_normal([5,5,dim_channel+dim_y,dim_W3], stddev=0.02), name='discrim_W1')
        # dim_W2 5x5 kernels over dim_W3+dim_y input channels.
        self.discrim_W2 = tf.Variable(tf.random_normal([5,5,dim_W3+dim_y,dim_W2], stddev=0.02), name='discrim_W2')
        self.discrim_W3 = tf.Variable(tf.random_normal([dim_W2*7*7+dim_y,dim_W1], stddev=0.02), name='discrim_W3')
        self.discrim_W4 = tf.Variable(tf.random_normal([dim_W1+dim_y,1], stddev=0.02), name='discrim_W4')

    def build_model(self):
        """Build the training graph.

        Returns:
            ``(Z, Y, image_real, discrim_cost, gen_cost, p_real, p_gen)``:
            the noise, label and real-image placeholders, the two losses,
            and the discriminator outputs for real and generated images.
        """
        Z = tf.placeholder(tf.float32, [self.batch_size, self.dim_z])
        Y = tf.placeholder(tf.float32, [self.batch_size, self.dim_y])

        image_real = tf.placeholder(tf.float32, [self.batch_size]+self.image_shape)
        image_gen = self.generate(Z, Y)

        p_real = self.discriminate(image_real, Y)  # probability for real images
        p_gen = self.discriminate(image_gen, Y)    # probability for generated images

        # The discriminator wants p_real close to 1 and p_gen close to 0.
        discrim_cost_real = bce(p_real, tf.ones_like(p_real))
        discrim_cost_gen = bce(p_gen, tf.zeros_like(p_gen))
        discrim_cost = tf.reduce_mean(discrim_cost_real) + tf.reduce_mean(discrim_cost_gen)

        # The generator wants p_gen close to 1.
        gen_cost = tf.reduce_mean(bce(p_gen, tf.ones_like(p_gen)))

        return Z, Y, image_real, discrim_cost, gen_cost, p_real, p_gen

    def discriminate(self, image, Y):
        """Return the probability that each image is real, shape (batch, 1).

        Args:
            image: tensor of shape (batch, 28, 28, dim_channel).
            Y: one-hot labels of shape (batch, dim_y).
        """
        yb = tf.reshape(Y, [self.batch_size, 1, 1, self.dim_y])  # (batch, 1, 1, dim_y)
        # Broadcast the label over every pixel and append it as channels.
        X = tf.concat([image, yb*tf.ones([self.batch_size, 28, 28, self.dim_y])], 3)

        # Layer 1: strided convolution -> (batch, 14, 14, dim_W3).
        h1 = lrelu(tf.nn.conv2d(X, self.discrim_W1, strides=[1,2,2,1], padding='SAME'))
        h1 = tf.concat([h1, yb*tf.ones([self.batch_size, 14, 14, self.dim_y])], 3)

        # Layer 2: strided convolution -> (batch, 7, 7, dim_W2), then flatten.
        h2 = lrelu(batchnormalize(tf.nn.conv2d(h1, self.discrim_W2, strides=[1,2,2,1], padding='SAME')))
        h2 = tf.reshape(h2, [self.batch_size, -1])
        h2 = tf.concat([h2, Y], 1)

        # Layer 3: fully connected -> (batch, dim_W1).
        h3 = lrelu(batchnormalize(tf.matmul(h2, self.discrim_W3)))
        h3 = tf.concat([h3, Y], 1)

        # Layer 4: project to a single logit per image. Previously
        # discrim_W4 was never applied and the sigmoid of the whole hidden
        # layer (including the raw label columns) was returned instead.
        h4 = tf.matmul(h3, self.discrim_W4)
        return tf.nn.sigmoid(h4)

    def _generator_net(self, Z, Y, batch_size):
        """Map noise ``Z`` and labels ``Y`` to images for a given batch size."""
        yb = tf.reshape(Y, [batch_size, 1, 1, self.dim_y])  # (batch, 1, 1, dim_y)
        h0 = tf.concat([Z, Y], 1)                            # (batch, dim_z+dim_y)

        # Layer 1: fully connected -> (batch, dim_W1).
        h1 = tf.nn.relu(batchnormalize(tf.matmul(h0, self.gen_W1)))
        h1 = tf.concat([h1, Y], 1)

        # Layer 2: fully connected, reshaped to (batch, 7, 7, dim_W2).
        h2 = tf.nn.relu(batchnormalize(tf.matmul(h1, self.gen_W2)))
        h2 = tf.reshape(h2, [batch_size, 7, 7, self.dim_W2])
        h2 = tf.concat([h2, yb*tf.ones([batch_size, 7, 7, self.dim_y])], 3)

        # Layer 3: transposed convolution -> (batch, 14, 14, dim_W3).
        output_shape_l3 = [batch_size, 14, 14, self.dim_W3]
        h3 = tf.nn.conv2d_transpose(h2, self.gen_W3, output_shape=output_shape_l3, strides=[1,2,2,1])
        h3 = tf.nn.relu(batchnormalize(h3))
        h3 = tf.concat([h3, yb*tf.ones([batch_size, 14, 14, self.dim_y])], 3)

        # Layer 4: transposed convolution -> (batch, 28, 28, dim_channel).
        output_shape_l4 = [batch_size, 28, 28, self.dim_channel]
        h4 = tf.nn.conv2d_transpose(h3, self.gen_W4, output_shape=output_shape_l4, strides=[1,2,2,1])

        # Sigmoid keeps the generated pixel values in (0, 1).
        return tf.nn.sigmoid(h4)

    def generate(self, Z, Y):
        """Generate ``self.batch_size`` images from noise ``Z`` and labels ``Y``."""
        return self._generator_net(Z, Y, self.batch_size)

    def samples_generator(self, batch_size):
        """Build a second generator graph for visualisation.

        It shares the generator weights with ``generate`` but has its own
        placeholders and batch size, so a differently sized batch can be
        sampled and tiled into one preview image.

        Returns:
            ``(Z, Y, x)``: the noise placeholder, the label placeholder and
            the generated image tensor.
        """
        Z = tf.placeholder(tf.float32, [batch_size, self.dim_z])
        Y = tf.placeholder(tf.float32, [batch_size, self.dim_y])
        x = self._generator_net(Z, Y, batch_size)
        return Z, Y, x
train.py文件用於訓練模型並生成樣例圖像
# !/usr/bin/python
#-*- coding: utf-8 -*-
import ipdb
import os
import pandas as pd
import numpy as np
from model import *
from util import *
from load import mnist_with_valid_set
# ---- Hyper-parameters -------------------------------------------------------
n_epochs = 100
learning_rate = 0.0002
batch_size = 128
image_shape = [28,28,1]
dim_z = 100
dim_W1 = 1024
dim_W2 = 128
dim_W3 = 64
dim_channel = 1
# Number of generated samples tiled into one preview image (14 x 14 tiles of
# 28 x 28 pixels give a 392 x 392 image).
visualize_dim = 196
vis_dir = './vis'

# trX/trY: training set, vaX/vaY: validation set, teX/teY: test set.
trX, vaX, teX, trY, vaY, teY = mnist_with_valid_set()

dcgan_model = DCGAN(
        batch_size=batch_size,
        image_shape=image_shape,
        dim_z=dim_z,
        dim_W1=dim_W1,
        dim_W2=dim_W2,
        dim_W3=dim_W3,
        dim_channel=dim_channel,
        )
# One-hot width must always match the label placeholder of the model.
n_classes = dcgan_model.dim_y

# Z_tf:      noise placeholder
# Y_tf:      one-hot label placeholder
# image_tf:  placeholder for a batch of real training images
# d_cost_tf: discriminator loss
# g_cost_tf: generator loss
# p_real:    discriminator output for real images
# p_gen:     discriminator output for generated images
Z_tf, Y_tf, image_tf, d_cost_tf, g_cost_tf, p_real, p_gen = dcgan_model.build_model()

sess = tf.InteractiveSession()
# NOTE(review): the saver is created but no checkpoint is written in this script.
saver = tf.train.Saver(max_to_keep=10)

# Split the trainable variables by name prefix. Real lists are built because
# ``filter`` returns a lazy iterator on Python 3.
discrim_vars = [v for v in tf.trainable_variables() if v.name.startswith('discrim')]
gen_vars = [v for v in tf.trainable_variables() if v.name.startswith('gen')]

# Adam with the learning rate defined above for both networks.
train_op_discrim = tf.train.AdamOptimizer(learning_rate, beta1=0.5).minimize(d_cost_tf, var_list=discrim_vars)
train_op_gen = tf.train.AdamOptimizer(learning_rate, beta1=0.5).minimize(g_cost_tf, var_list=gen_vars)

Z_tf_sample, Y_tf_sample, image_tf_sample = dcgan_model.samples_generator(batch_size=visualize_dim)

# Must run after the optimizers exist so their slot variables are initialised too.
tf.global_variables_initializer().run()

# Fixed noise and labels, reused for every preview so the images are comparable.
Z_np_sample = np.random.uniform(-1, 1, size=(visualize_dim, dim_z))
Y_np_sample = OneHot(np.random.randint(n_classes, size=[visualize_dim]), n=n_classes)

if not os.path.isdir(vis_dir):
    os.makedirs(vis_dir)

iterations = 0
k = 2  # the generator is updated whenever iterations % k != 0, otherwise the discriminator

for epoch in range(n_epochs):
    # Reshuffle the training data at the start of every epoch.
    index = np.random.permutation(len(trY))
    trX = trX[index]
    trY = trY[index]

    # Only full batches are used because the placeholders have a fixed batch
    # size; a trailing partial batch is skipped.
    for start in range(0, len(trY) - batch_size + 1, batch_size):
        end = start + batch_size

        Xs = trX[start:end].reshape([-1, 28, 28, 1]) / 255.  # pixels scaled to [0, 1]
        # Fixed width: a batch without the largest digit would otherwise
        # yield fewer columns than the placeholder expects.
        Ys = OneHot(trY[start:end], n=n_classes)
        Zs = np.random.uniform(-1, 1, size=[batch_size, dim_z]).astype(np.float32)

        if np.mod(iterations, k) != 0:
            # Update the generator ...
            _, gen_loss_val = sess.run(
                    [train_op_gen, g_cost_tf],
                    feed_dict={
                        Z_tf: Zs,
                        Y_tf: Ys
                        })
            # ... then only evaluate the discriminator (no update).
            discrim_loss_val, p_real_val, p_gen_val = sess.run(
                    [d_cost_tf, p_real, p_gen],
                    feed_dict={Z_tf: Zs, image_tf: Xs, Y_tf: Ys})
            print("=========== updating G ==========")
        else:
            # Update the discriminator ...
            _, discrim_loss_val = sess.run(
                    [train_op_discrim, d_cost_tf],
                    feed_dict={
                        Z_tf: Zs,
                        Y_tf: Ys,
                        image_tf: Xs
                        })
            # ... then only evaluate the generator (no update).
            gen_loss_val, p_real_val, p_gen_val = sess.run(
                    [g_cost_tf, p_real, p_gen],
                    feed_dict={Z_tf: Zs, image_tf: Xs, Y_tf: Ys})
            print("=========== updating D ==========")

        # Single-argument print calls behave the same on Python 2 and 3.
        print("iteration: %s" % iterations)
        print("gen loss: %s" % gen_loss_val)
        print("discrim loss: %s" % discrim_loss_val)
        print("Average P(real)= %s" % p_real_val.mean())
        print("Average P(gen)= %s" % p_gen_val.mean())

        # Every 200 iterations (not epochs) write a preview image. The sample
        # graph shares its weights with the training graph, so it reflects
        # the current state of the generator.
        if np.mod(iterations, 200) == 0:
            generated_samples = sess.run(
                    image_tf_sample,
                    feed_dict={
                        Z_tf_sample: Z_np_sample,
                        Y_tf_sample: Y_np_sample
                        })
            # NOTE(review): this maps [-1, 1] to [0, 1], but the generator
            # ends in a sigmoid, so its output is already in (0, 1) - confirm
            # whether this rescaling is still wanted.
            generated_samples = (generated_samples + 1.)/2.
            save_visualization(generated_samples, (14,14), save_path='./vis/sample_'+str(iterations//200)+'.jpg')

        iterations += 1
對於此項目,我還存在一些疑惑,比如說:標籤Y的維度(dim_y = 10)爲什麼在網絡的每一層都被拼接進去,從而改變各層的輸入維度?例如:
h2 = tf.reshape(h2,[self.batch_size,7,7,self.dim_W2])#shape =(128,7,7,128)
h2 = tf.concat([h2,yb * tf.ones([self.batch_size,7,7,self.dim_y])],3)#shape =(128,7,7,138)
在這裏,h2增加了10個通道。它的作用是什麼呢?