tensorflow風格遷移網絡訓練與使用

風格遷移原理解釋

卷積神經網絡實現圖像風格遷移在2015的一篇論文中最早出現。實現了一張從一張圖像中提取分割,從另外一張圖像中提取內容,疊加生成一張全新的圖像。早前風靡一時的風格遷移APP – Prisma其背後就是圖像各種風格遷移、讓人耳目一新。其主要的思想是對於訓練好的卷積神經網絡,其內部一些feature map跟最終識別的對象是特徵獨立的,這些特徵當中有一些是關於內容特徵的,另外一些是關於風格特徵的,於是我們可以輸入兩張圖像,從其中一張圖像上提取其內容特徵,另外一張圖像上提取其風格特徵,然後把它們疊加在一起形成一張新的圖像,這個就風格遷移卷積網絡。最常見的我們是用一個預先訓練好的卷積神經網絡,常見的就是VGG-19,其結構如下:

其包含16個卷積層、5個池化層、3個全鏈接層。其中:

表示內容層爲:relu4-2 表示風格層爲:relu1_1, relu2_1, relu3_1, relu4_1, relu5_1

越高階的層圖像內容越抽象,我們損失的像素信息越多,所有選用relu4-2層作爲內容層而忽略低階的內容損失,對於風格來說,它是從低階到高階的層組合。所以選用從低到高不同層組合作爲風格[relu1_1, relu2_1, relu3_1, relu4_1, relu5_1]

遷移損失

風格遷移生成圖像Y, 要求它的內容來自圖像C, 要求它的風格來自圖像S。

Y是隨機初始化的一張圖像,帶入到預訓練的網絡中會得到內容層與風格層的輸出結果 C是內容圖像,帶入到預訓練的網絡中得到內容層Target標籤 S是風格圖像,帶入到預訓練的網絡中得到風格層Target標籤

這樣總的損失函數就是內容與標籤的損失,此外我們希望最終生成的圖像是一張光滑圖像,所有還有一個像素方差損失,這三個損失分別表示爲 :

Loss(content)、 Loss(style) 、 Loss(var)

最終總的損失函數爲:

Total Loss = alpha * Loss (content) + beta * Loss (Style) + Loss (var) 其中alpha與beta分別是內容損失與風格損失的權重大小

代碼實現:

獲取內容圖像C與風格圖像S的標籤

# Get network parameters
image = tf.placeholder('float', shape=shape)
vgg_net = vgg_network(network_weights, image)

# Normalize original image
original_minus_mean = content_image - normalization_mean
original_norm = np.array([original_minus_mean])
original_features[content_layers] = sess.run(vgg_net[content_layers],
                                             feed_dict={image: original_norm})

# Get style image network
image = tf.placeholder('float', shape=style_shape)
vgg_net = vgg_network(network_weights, image)
style_minus_mean = style_image - normalization_mean
style_norm = np.array([style_minus_mean])

for layer in style_layers:
    layer_output = sess.run(vgg_net[layer], feed_dict={image: style_norm})
    layer_output = np.reshape(layer_output, (-1, layer_output.shape[3]))
    style_gram_matrix = np.matmul(layer_output.T, layer_output) / layer_output.size
    style_features[layer] = style_gram_matrix

隨機初始化Y圖像

#  隨機初始化目標圖像
initial = tf.random_normal(shape) * 0.256
image = tf.Variable(initial)
vgg_net = vgg_network(network_weights, image)

計算內容損失

# 計算目標圖像內容與內容圖像之間的差異, 內容損失
original_loss = original_image_weight * (
2 * tf.nn.l2_loss(vgg_net[content_layers] - original_features[content_layers]) /
original_features[content_layers].size)

計算風格損失

# 風格損失
style_loss = 0
style_losses = []
for style_layer in style_layers:
    layer = vgg_net[style_layer]
    feats, height, width, channels = [x.value for x in layer.get_shape()]
    size = height * width * channels
    features = tf.reshape(layer, (-1, channels))
    style_gram_matrix = tf.matmul(tf.transpose(features), features) / size
    style_expected = style_features[style_layer]
    style_losses.append(2 * tf.nn.l2_loss(style_gram_matrix - style_expected) / style_expected.size)
style_loss += style_image_weight * tf.reduce_sum(style_losses)

添加smooth損失

# To Smooth the resuts, we add in total variation loss
total_var_x = sess.run(tf.reduce_prod(image[:, 1:, :, :].get_shape()))
total_var_y = sess.run(tf.reduce_prod(image[:, :, 1:, :].get_shape()))
first_term = regularization_weight * 2
second_term = (tf.nn.l2_loss(image[:, 1:, :, :] - image[:, :shape[1] - 1, :, :]) / total_var_y)
third_term = (tf.nn.l2_loss(image[:, :, 1:, :] - image[:, :, :shape[2] - 1, :]) / total_var_x)
total_variation_loss = first_term * (second_term + third_term)

訓練風格遷移

# 總的損失
loss = original_loss + style_loss + total_variation_loss

# 優化器
optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2)
train_step = optimizer.minimize(loss)

# 初始化參數與訓練
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for i in range(generations):
    sess.run(train_step)
    # Print update and save temporary output
    if (i + 1) % output_generations == 0:
        print('Generation {} out of {}, loss: {}'.format(i + 1, generations, sess.run(loss)))
        image_eval = sess.run(image)
        best_image = image_eval.reshape(shape[1:]) + normalization_mean
        temp_img = np.clip(best_image, 0, 255).astype(np.uint8)
        output_file = 'D:/pet_data/temp_output_{}.jpg'.format(i)
        Image.fromarray(temp_img).save(output_file, quality=95)
saver.save(sess, "./neural_style.model", global_step=2500)

運行結果

輸入圖像(右下角爲風格圖像),輸出圖像

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章