[Code Reading] WarpGAN: Automatic Caricature Generation

Code link

Reference book: *TensorFlow 實戰Google深度學習框架* (TensorFlow in Action: Google's Deep Learning Framework)

I think reading Chapter 3 of it gives a clearer picture of how TensorFlow builds and trains a neural network.

1. train.py

This file defines the main function:

def main(args):

Initialization:

# Initialization for running
    if config.save_model:
        log_dir = utils.create_log_dir(config, config_file)
        summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model:
        network.restore_model(config.restore_model, config.restore_scopes)

    proc_func = lambda images: preprocess(images, config, True)
    trainset.start_batch_queue(config.batch_size, proc_func=proc_func)

The config settings used here all come from the file WarpGAN\config\default.py.
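As a rough, hypothetical illustration (not the repository's actual default.py), a config module of this kind only needs module-level variables for the fields that the excerpts in this post read as config.xxx; all values below are placeholders:

    # Hypothetical sketch of a WarpGAN-style config module (placeholder values only)
    name = 'warpgan_example'     # experiment name, used for the log directory
    num_epochs = 100             # number of training epochs
    epoch_size = 5000            # training steps per epoch
    batch_size = 2               # images per training step
    summary_interval = 100       # steps between TensorBoard summaries
    keep_prob = 1.0              # dropout keep probability fed to the graph
    image_size = (256, 256)      # (h, w) of the input placeholders
    channels = 3                 # RGB
    coef_adv = 1.0               # weight of the class-conditional adversarial loss
    coef_patch_adv = 1.0         # weight of the patch adversarial loss
    coef_idt = 10.0              # weight of the identity (reconstruction) loss
    save_model = True            # write checkpoints and summaries
    restore_model = None         # path of a checkpoint to restore, if any
    restore_scopes = None        # variable scopes to restore from that checkpoint
    # the learning-rate schedule is also configured here and read by utils.get_updated_learning_rate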

Dataset reading, initialization, and related operations come from the file WarpGAN\utils\dataset.py.

Main loop:

# Main Loop
    print('\nStart Training\nname: {}\n# epochs: {}\nepoch_size: {}\nbatch_size: {}\n'.format(
            config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0
    start_time = time.time()
    for epoch in range(config.num_epochs):

        if epoch == 0: test(network, config, log_dir, global_step)

        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(global_step, config)
            batch = trainset.pop_batch_queue()

            wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

            wl['lr'] = learning_rate

            # Display
            if step % config.summary_interval == 0:
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)
                if config.save_model:
                    summary_writer.add_summary(sm, global_step=global_step)

wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

This line is the key point: it invokes one training step of the network.

2. warpgan.py

This file defines the computation graph of the WarpGAN network: the forward pass and the loss functions.

The process of training a neural network can be summarized in the following three steps (a minimal generic sketch follows the list):

1) Define the network structure and the output of the forward pass

2) Define the loss function (computed from the forward-pass output) and the back-propagation optimization algorithm

3) Create a session (tf.Session()) and repeatedly run the back-propagation optimization on the training data
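As a minimal, generic sketch of this three-step pattern in TensorFlow 1.x (the toy model, shapes, and random data below are made up and unrelated to WarpGAN):

    import numpy as np
    import tensorflow as tf

    # 1) Define the network structure and the forward-pass output
    x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
    y = tf.placeholder(tf.float32, shape=[None, 1], name='y')
    w = tf.Variable(tf.random_normal([4, 1]), name='w')
    y_pred = tf.matmul(x, w)                          # forward-pass output

    # 2) Define the loss (computed from the forward pass) and the optimizer
    loss = tf.reduce_mean(tf.square(y_pred - y))
    train_op = tf.train.AdamOptimizer(1e-2).minimize(loss)

    # 3) Create a session and repeatedly run the optimization on training data
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(1000):
            batch_x = np.random.rand(8, 4).astype(np.float32)
            batch_y = batch_x.sum(axis=1, keepdims=True)
            _, loss_val = sess.run([train_op, loss], feed_dict={x: batch_x, y: batch_y})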

    def train(self, images_batch, labels_batch, switch_batch, learning_rate, keep_prob):
        images_A = images_batch[~switch_batch]
        images_B = images_batch[switch_batch]
        labels_A = labels_batch[~switch_batch]
        labels_B = labels_batch[switch_batch]
        scales_A = np.ones((images_A.shape[0]))
        scales_B = np.ones((images_B.shape[0]))
        feed_dict = {   self.images_A: images_A,
                        self.images_B: images_B,
                        self.labels_A: labels_A,
                        self.labels_B: labels_B,
                        self.scales_A: scales_A,
                        self.scales_B: scales_B,
                        self.learning_rate: learning_rate,
                        self.keep_prob: keep_prob,
                        self.phase_train: True,}
        _, wl, sm = self.sess.run([self.train_op, self.watch_list, self.summary_op], feed_dict = feed_dict)

        step = self.sess.run(self.global_step)

        return wl, sm, step

The train function is called by train.py above; it corresponds to step 3: running the optimization inside the session.

Here, self.train_op, self.watch_list, self.summary_op, and self.global_step are all operations (nodes) in the graph.

We mainly focus on self.train_op, the operation that updates the parameters.

It is defined in the initialize function, which builds the forward pass and the loss functions:

    def initialize(self, config, num_classes=None):
        '''
            Initialize the graph from scratch according to config.
        '''
        with self.graph.as_default():
            with self.sess.as_default():
                # Set up placeholders
                h, w = config.image_size
                channels = config.channels
                self.images_A = tf.placeholder(tf.float32, shape=[None, h, w, channels], name='images_A')
                self.images_B = tf.placeholder(tf.float32, shape=[None, h, w, channels], name='images_B')
                self.labels_A = tf.placeholder(tf.int32, shape=[None], name='labels_A')
                self.labels_B = tf.placeholder(tf.int32, shape=[None], name='labels_B')
                self.scales_A = tf.placeholder(tf.float32, shape=[None], name='scales_A')
                self.scales_B = tf.placeholder(tf.float32, shape=[None], name='scales_B')

                self.learning_rate = tf.placeholder(tf.float32, name='learning_rate')
                self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
                self.phase_train = tf.placeholder(tf.bool, name='phase_train')
                self.global_step = tf.Variable(0, trainable=False, dtype=tf.int32, name='global_step')

                self.setup_network_model(config, num_classes)

                # Build generator
                encode_A, styles_A = self.encoder(self.images_A)
                encode_B, styles_B = self.encoder(self.images_B)

                deform_BA, render_BA, ldmark_pred, ldmark_diff = self.decoder(encode_B, self.scales_B, None)
                render_AA = self.decoder(encode_A, self.scales_A, styles_A, texture_only=True)
                render_BB = self.decoder(encode_B, self.scales_B, styles_B, texture_only=True)

                self.styles_A = tf.identity(styles_A, name='styles_A')
                self.styles_B = tf.identity(styles_B, name='styles_B')
                self.deform_BA = tf.identity(deform_BA, name='deform_BA')
                self.ldmark_pred = tf.identity(ldmark_pred, name='ldmark_pred')
                self.ldmark_diff = tf.identity(ldmark_diff, name='ldmark_diff')


                # Build discriminator for real images
                patch_logits_A, logits_A = self.discriminator(self.images_A)
                patch_logits_B, logits_B = self.discriminator(self.images_B)
                patch_logits_BA, logits_BA = self.discriminator(deform_BA)                          

                # Show images in TensorBoard
                image_grid_A = tf.stack([self.images_A, render_AA], axis=1)[:1]
                image_grid_B = tf.stack([self.images_B, render_BB], axis=1)[:1]
                image_grid_BA = tf.stack([self.images_B, deform_BA], axis=1)[:1]
                image_grid = tf.concat([image_grid_A, image_grid_B, image_grid_BA], axis=0)
                image_grid = tf.reshape(image_grid, [-1] + list(self.images_A.shape[1:]))
                image_grid = self.image_grid(image_grid, (3,2))
                tf.summary.image('image_grid', image_grid)


                # Build all losses
                self.watch_list = {}
                loss_list_G  = []
                loss_list_D  = []
               
                # Adversarial loss for deform_BA
                loss_D, loss_G = self.cls_adv_loss(logits_A, logits_B, logits_BA,
                    self.labels_A, self.labels_B, self.labels_B, num_classes)
                loss_D, loss_G = config.coef_adv*loss_D, config.coef_adv*loss_G

                self.watch_list['LDg'] = loss_D
                self.watch_list['LGg'] = loss_G
                loss_list_D.append(loss_D)
                loss_list_G.append(loss_G)

                # Patch adversarial loss for deform_BA
                loss_D, loss_G = self.patch_adv_loss(patch_logits_A, patch_logits_B, patch_logits_BA)
                loss_D, loss_G = config.coef_patch_adv*loss_D, config.coef_patch_adv*loss_G

                self.watch_list['LDp'] = loss_D
                self.watch_list['LGp'] = loss_G
                loss_list_D.append(loss_D)
                loss_list_G.append(loss_G)

                # Identity Mapping (Reconstruction) loss
                loss_idt_A = tf.reduce_mean(tf.abs(render_AA - self.images_A), name='idt_loss_A')
                loss_idt_A = config.coef_idt * loss_idt_A

                loss_idt_B = tf.reduce_mean(tf.abs(render_BB - self.images_B), name='idt_loss_B')
                loss_idt_B = config.coef_idt * loss_idt_B

                self.watch_list['idtA'] = loss_idt_A
                self.watch_list['idtB'] = loss_idt_B
                loss_list_G.append(loss_idt_A+loss_idt_B)


                # Collect all losses
                reg_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), name='reg_loss')
                self.watch_list['reg_loss'] = reg_loss
                loss_list_G.append(reg_loss)
                loss_list_D.append(reg_loss)


                loss_G = tf.add_n(loss_list_G, name='loss_G')
                grads_G = tf.gradients(loss_G, self.G_vars)

                loss_D = tf.add_n(loss_list_D, name='loss_D')
                grads_D = tf.gradients(loss_D, self.D_vars)

                # Training Operators
                train_ops = []

                opt_G = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5, beta2=0.9)
                opt_D = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5, beta2=0.9)
                apply_G_gradient_op = opt_G.apply_gradients(list(zip(grads_G, self.G_vars)))
                apply_D_gradient_op = opt_D.apply_gradients(list(zip(grads_D, self.D_vars)))

                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                train_ops.extend([apply_G_gradient_op, apply_D_gradient_op] + update_ops)

                train_ops.append(tf.assign_add(self.global_step, 1))
                self.train_op = tf.group(*train_ops)

                # Collect TF summary
                for k,v in self.watch_list.items():
                    tf.summary.scalar('losses/' + k, v)
                tf.summary.scalar('learning_rate', self.learning_rate)
                self.summary_op = tf.summary.merge_all()

                # Initialize variables
                self.sess.run(tf.local_variables_initializer())
                self.sess.run(tf.global_variables_initializer())
                self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=99)

 

Based on the forward pass defined here, we can sketch the forward-propagation diagram.
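In text form, the data flow that initialize builds is roughly the following (with B being the photo domain, judging from the is_photo split in train.py, and A the caricature domain):

    images_A --encoder--> encode_A, styles_A        (A: caricatures)
    images_B --encoder--> encode_B, styles_B        (B: photos)

    decoder(encode_B, scales_B, None)                          --> deform_BA, render_BA, ldmark_pred, ldmark_diff
    decoder(encode_A, scales_A, styles_A, texture_only=True)   --> render_AA  (reconstruction of A)
    decoder(encode_B, scales_B, styles_B, texture_only=True)   --> render_BB  (reconstruction of B)

    discriminator(images_A), discriminator(images_B), discriminator(deform_BA)
        --> logits_* and patch_logits_* used by the adversarial losses

    Generator loss     = cls_adv_loss(G) + patch_adv_loss(G) + L1(render_AA, images_A) + L1(render_BB, images_B) + reg
    Discriminator loss = cls_adv_loss(D) + patch_adv_loss(D) + reg
    Each is minimized by its own Adam optimizer over G_vars and D_vars respectively.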

Combining this with the network architecture diagram in the paper, we can draw a more detailed process diagram.

3. WarpGAN\models\default.py

This file defines the detailed structure of the three sub-networks (encoder, decoder, and discriminator) that warpgan.py calls.

To find out how the landmarks are learned, I mainly looked at the WarpController sub-network inside the decoder, which generates the facial control points:

                    with tf.variable_scope('WarpController'):

                        print('-- WarpController')

                        net = encoded
                        warp_input = tf.identity(images_rendered, name='warp_input')

                        net = slim.flatten(net)

                        net = slim.fully_connected(net, 128, scope='fc1')
                        print('module fc1 shape:', [dim.value for dim in net.shape])

                        num_ldmark = 16

                        # Predict the control points
                        ldmark_mean = (np.random.normal(0,50, (num_ldmark,2)) + np.array([[0.5*h,0.5*w]])).flatten()
                        ldmark_mean = tf.Variable(ldmark_mean.astype(np.float32), name='ldmark_mean')
                        print('ldmark_mean shape:', [dim.value for dim in ldmark_mean.shape])

                        ldmark_pred = slim.fully_connected(net, num_ldmark*2, 
                            weights_initializer=tf.truncated_normal_initializer(stddev=1.0),
                            normalizer_fn=None, activation_fn=None, biases_initializer=None, scope='fc_ldmark')
                        ldmark_pred = ldmark_pred + ldmark_mean
                        print('ldmark_pred shape:', [dim.value for dim in ldmark_pred.shape])
                        ldmark_pred = tf.identity(ldmark_pred, name='ldmark_pred')
                 

                        # Predict the displacements
                        ldmark_diff = slim.fully_connected(net, num_ldmark*2, 
                            normalizer_fn=None,  activation_fn=None, scope='fc_diff')
                        print('ldmark_diff shape:', [dim.value for dim in ldmark_diff.shape])
                        ldmark_diff = tf.identity(ldmark_diff, name='ldmark_diff')
                        ldmark_diff = tf.identity(tf.reshape(scales,[-1,1]) * ldmark_diff, name='ldmark_diff_scaled')



                        src_pts = tf.reshape(ldmark_pred, [-1, num_ldmark ,2])
                        dst_pts = tf.reshape(ldmark_pred + ldmark_diff, [-1, num_ldmark, 2])

                        diff_norm = tf.reduce_mean(tf.norm(src_pts-dst_pts, axis=[1,2]))
                        # tf.summary.scalar('diff_norm', diff_norm)
                        # tf.summary.scalar('mark', ldmark_pred[0,0])

                        images_transformed, dense_flow = sparse_image_warp(warp_input, src_pts, dst_pts,
                                regularization_weight = 1e-6, num_boundary_points=0)
                        dense_flow = tf.identity(dense_flow, name='dense_flow')

My understanding is as follows (a small numeric sketch follows the two points):

1) Control points: ldmark_pred is the sum of ldmark_mean and the output of the fully-connected layer fc_ldmark. ldmark_mean is a trainable variable that is initialized once as random points scattered around the image center (center plus Gaussian noise) and is then updated by training; fc_ldmark takes the encoded image as input and predicts a per-image offset, which is added to ldmark_mean to give the control points for that image.

2) Control-point displacements: ldmark_diff is predicted from the encoded image by another fully-connected layer (fc_diff) and multiplied by the scale factor; the warp then moves each point from src_pts = ldmark_pred to dst_pts = ldmark_pred + ldmark_diff.
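A minimal numpy sketch of this arithmetic (the fc outputs below are random stand-ins for what fc_ldmark and fc_diff would produce for one encoded image):

    import numpy as np

    num_ldmark, h, w = 16, 256, 256

    # trainable mean control points, initialized once around the image center
    ldmark_mean = (np.random.normal(0, 50, (num_ldmark, 2))
                   + np.array([[0.5 * h, 0.5 * w]])).flatten()

    fc_ldmark_out = np.random.randn(num_ldmark * 2)  # stand-in for the fc_ldmark output
    fc_diff_out = np.random.randn(num_ldmark * 2)    # stand-in for the fc_diff output
    scale = 1.0                                      # the scales entry for this image

    ldmark_pred = ldmark_mean + fc_ldmark_out        # predicted control points
    ldmark_diff = scale * fc_diff_out                # scaled displacements

    src_pts = ldmark_pred.reshape(num_ldmark, 2)                  # where the points are
    dst_pts = (ldmark_pred + ldmark_diff).reshape(num_ldmark, 2)  # where the warp moves them
    # sparse_image_warp then fits a dense flow field that maps src_pts onto dst_pts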

The paper's explanation of the relationship between deformation and style transfer is roughly as follows (my translation):

Unlike other visual style-transfer tasks, turning a photo into a caricature involves both texture differences and geometric transformation. Texture exaggerates local fine-grained features, such as the depth of wrinkles, while geometric deformation allows exaggeration of global features, such as face shape. Conventional style-transfer networks aim to reconstruct the image from a feature space with a decoder network. Since the decoder is a stack of nonlinear local filters, it is inherently limited in spatial variation; when there is a large geometric discrepancy between the input and output domains, the decoder suffers from poor image quality and severe information loss. Warping-based methods, on the other hand, are limited in that they cannot change content or fine-grained details. Therefore, both the style-transfer module and the deformation module are indispensable parts of the learning framework.

As shown in the figure below, without either module the generator cannot close the gap between photos and caricatures, the adversarial balance between the generator and the discriminator breaks down, and the results collapse.

So in this paper the style transfer and the deformation have to happen together; it is not enough to just find landmarks and change the shape.

I recently came across another paper, CariGANs, which, like this one, deforms the face based on landmarks; I think it is worth reading next.

 

 
