MXNET深度學習框架-30-樣式遷移

        如果你喜歡拍照,可能接觸過濾鏡(如下圖)。它能改變照片的顏色樣式,從而使風景照更加銳利或者令人像更加美白。但一個濾鏡通常只能改變照片的某個方面。如果要照片達到理想中的樣式,經常需要嘗試大量不同的組合,其複雜程度不亞於模型調參。

在這裏插入圖片描述
具體的原理本章就不做介紹了,請讀者自行百度或閱讀相關論文。

基本結構
在這裏插入圖片描述
        上圖中,分別使用CNN對樣式圖片和內容圖片抽取特徵,比如,通過第1、2、4層卷積抽取樣式圖片特徵,通過第3層卷積抽取內容圖片特徵,然後通過正向傳播計算樣式遷移的損失函數,通過反向傳播迭代模型參數,即不斷更新合成圖像。樣式遷移常用的損失函數由3部分組成:內容損失(content loss)使合成圖像與內容圖像在內容特徵上接近,樣式損失(style loss)令合成圖像與樣式圖像在樣式特徵上接近,而總變差損失(total variation loss)則有助於減少合成圖像中的噪點。最後,當模型訓練結束時,我們輸出樣式遷移的模型參數,即得到最終的合成圖像。一般步驟如下:
        1)挑選網絡特定層分別作爲樣式層和內容層;
        2)輸入樣式圖片並保存樣式層輸出,記爲sis_i;
        3) 輸入內容圖片並保存內容層輸出,記爲cic_i
        4) 初始化合成圖片X爲隨機值,然後進行迭代,使得用xx抽取的特徵能夠匹配上sis_icic_i,具體來說,我們如下迭代直至收斂:
                a) 輸入xx計算樣式層和內容層輸出,記ii層輸出爲yiy_i
                b) 使用樣式損失函數計算yiy_isis_i的差異;
                c) 使用內容損失函數計算yiy_icic_i的差異;
                d) 對損失求和並對xx求導,記導數爲gg
                e) 更新xx,例如xηgx-\eta g

        內容損失函數通常使用迴歸的均方誤差,對於樣式,我們通常將它看成是像素點在每個通道的統計分佈。例如要匹配兩張圖片的顏色,我們的一個做法是匹配這兩張圖片在RGB三個通道的直方圖,更一般的,假設卷積輸出的格式是c×h×wc×h×w,那麼我們可以把它變形成一個c×hwc×hw的2D矩陣,並將它看成是一個維度爲cc的隨機變量採樣到的hwhw個點。所謂的樣式匹配就是使得兩個cc維隨機變量統計分佈一致。

        匹配統計分佈常用的做法是衝量匹配,就是說使它們有一樣的均值、協方差和其它高維的衝量。爲了計算簡單起見,我們這裏假設卷積輸出已經是均值爲0了,而且,我們只匹配協方差,也就是說,樣式損失函數就是對sis_iyiy_i計算Gram矩陣然後應用均方誤差:
styleloss(si,yi)=1c2hwsisiTyiyiTF2styleloss(s_i,y_i)=\frac{1}{c^2hw}\quad||s_is_i^T-y_iy_i^T||_F^2

        這裏假設我們已經將sis_iyiy_i變形成了c×hwc×hw的2D矩陣了。

        下面我們將實現這個算法來深入理解各個參數,例如樣式層和內容層的選取,對實際結果的影響。

1、數據

樣式圖片:
在這裏插入圖片描述
內容圖片:
在這裏插入圖片描述

2、數據處理

rgb_mean=nd.array([0.485,0.456,0.406])
rgb_std=nd.array([0.229,0.224,0.225])
def image_preprocess(image,input_shape):
    image=cv.resize(image,*input_shape)
    img=(nd.array(image).astype("float32") / 255.0 - rgb_mean) / rgb_std
    return img.transpose((2,0,1)).expand_dims(axis=0)

def postprocess(image):# 後處理圖像
    new_image=(image.transpose((1, 2, 0)) * rgb_std + rgb_mean).clip(0, 1)
    new_image=(new_image*255).asnumpy().astype(np.uint8)
    return new_image

3、模型

        本章我們使用的是VGG-19的預訓練模型,它已經在ImageNet上做過預訓練了。

pretrained_net = models.vgg19(pretrained=True)
print(pretrained_net)

模型結構:
在這裏插入圖片描述在這裏插入圖片描述在這裏插入圖片描述
        從結構上來看,VGG-19有5個卷積塊(5個池化),這裏我們按照原論文的每個卷積塊的第一個卷積輸出來匹配樣式,第四個卷積塊的最後一個卷積輸出來匹配內容。根據打印的結果我們來記錄下位置:

style_layers=[0,5,10,19,28]
content_layers=[25]

4、構建新網絡

        因爲只需要用到中間層的輸出,所以需要構建新的網絡,它只保留需要用到的VGG的所有層:

def get_net(pretrained_net,content_layers,style_layers):
    net=gn.nn.Sequential()
    for i in range(max(content_layers+style_layers)+1):
        # print(i)
        net.add(pretrained_net.features[i])
    return net
net=get_net(pretrained_net,content_layers,style_layers)

給定輸入x,我們能拿到最後一層的輸出,現在我們需要拿到指定層的輸出:

def extract_features(x,content_layers,style_layers):
    contents=[]
    styles=[]
    for i in range(len(net)):
        x=net[i](x)
        if i in style_layers:
            styles.append(x)
        if i in content_layers:
            contents.append(x)
    return contents, styles

5、損失函數

        內容匹配是一個迴歸問題,這裏使用均方誤差來比較內容層的輸出:
1) 內容匹配損失函數

def content_loss(yhat,y):
    return (yhat-y).square().mean()

2)樣式匹配損失函數,通過擬合Gram矩陣

def gram(x):
    # 根據公式,獲得channel
    c, h, w = x.shape[1], x.shape[2], x.shape[3]
    hw = h * w
    y = x.reshape((c, int(hw)))
    return nd.dot(y, y.T) / hw # 除hw爲了normliza一下

3)總變差降噪
        當使用輸出層的高層輸出來擬合時,會發現學到的圖片中含有大量噪聲,對於解決噪聲,可以利用模糊濾鏡(中值濾波、高斯濾波等,但會使圖片邊緣信息也模糊),或者總變差降噪。

        假設xi,jx_{i,j}表示每個像素,那麼加入下面的損失函數,它使得鄰近的像素值相似:
i,jxi,jxi+1,j+xi,jxi,j+1 \sum_{i,j}^\infty |x_{i,j}-x_{i+1,j}|+|x_{i,j}-x_{i,j+1}|
所以,根據公式有:

def tv_loss(yhat):
    return 0.5 * ((yhat[:, :, 1:, :] - yhat[:, :, :-1, :]).abs().mean() +
                  (yhat[:, :, :, 1:] - yhat[:, :, :, :-1]).abs().mean())

        總損失函數是上述三個損失函數的加權和,通過調整權重值,我們可以控制學到的圖片是否保留更多樣式、內容。注意到樣式匹配中我們使用了5個層的輸出,所以所以應該對靠近輸入層的層給予更大的權重。

channels=[net[l].weight.shape[0] for l in style_layers] #拿出channel數
style_weights=[1e4/n**2 for n in channels]
content_weight=[1]
tv_weight=10

def sum_loss(loss,preds,truths,weights):
    return nd.add_n(*[w*loss(yhat,y) for w,yhat,y in zip(weights,preds,truths)])

6、訓練

# 6、訓練
# 定義兩個函數,分別對樣式圖片和內容圖片提取特徵
def get_style(style_image,image_shape):
    style_x=image_preprocess(style_image,image_shape,)
    _,style_y=extract_features(style_x,content_layers,style_layers)
    style_y=[gram(y) for y in style_y]
    return style_x,style_y

def get_content(content_image,image_shape):
    content_x=image_preprocess(content_image,image_shape)
    content_y,_ = extract_features(content_x, content_layers, style_layers)

    return content_x,content_y

image_shape=(200,300)
content_x,content_y=get_content(content_image,image_shape)

style_x,style_y=get_style(style_image,image_shape)
x=content_x.copyto(ctx) # 初始化合成圖像x,將合成圖片的初始值設爲(樣式)內容圖片來加速收斂
x.attach_grad()

def train(x,max_epochs,lr,lr_decay_epoch=200):
    for i in range(max_epochs):
        with ag.record():
            content_py,style_py=extract_features(x,content_layers,style_layers)

            content_L=sum_loss(content_loss,content_py,content_y,content_weight)
            style_L=sum_loss(style_loss,style_py,style_y,style_weights)
            tv_L=tv_loss(x)*tv_weight
            loss=content_L+style_L+tv_L
        loss.backward()
        x.grad[:]/=x.grad.abs().mean()+1e-8 #使得它對lr沒那麼敏感
        x[:]-=lr*x.grad
        nd.waitall()
        if i and i%5==0:
            print("batch %2d,content %.2f,style %.2f,TV %.2f"%(
                i,content_L.asscalar(),style_L.asscalar(),tv_L.asscalar()
            ))
        if i and i%lr_decay_epoch==0:
            lr*=0.1

y=train(x,1000,0.1)

訓練結果:
在這裏插入圖片描述

7、顯示合成的圖片

y=nd.squeeze(y,0).as_in_context(mx.cpu())

result=postprocess(y) # 顯示合成的x
result=cv.cvtColor(result,cv.COLOR_RGB2BGR)
cv.imshow("fix image",result)
cv.waitKey(0)
cv.destroyAllWindows()

讓我們來看看最終結果:
樣式圖片:
在這裏插入圖片描述
內容圖片:
在這裏插入圖片描述
合成的圖片:
在這裏插入圖片描述
        是不是有那麼點味道了?要想讓它更能學習到樣式和內容圖片的特徵,並且加速收斂,可以把初始化的x(合成圖片)設爲上面生成的合成圖片,讓上面生成的合成圖片去做權值更新。

下面放上所有代碼:

import mxnet.ndarray as nd
import mxnet as mx
import mxnet.gluon as gn
import mxnet.autograd as ag
from mxnet.gluon.model_zoo import vision as models
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
ctx=mx.gpu()
# 1、數據
style_image=cv.imread("style_content/style.jpeg")
content_image=cv.imread("style_content/content.jpg")

# 2、數據預處理
rgb_mean = nd.array([0.485, 0.456, 0.406])
rgb_std = nd.array([0.229, 0.224, 0.225])


def image_preprocess(image, input_shape):
    image = cv.resize(image, (input_shape[1],input_shape[0]))
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    img = (nd.array(image).astype("float32") / 255.0 - rgb_mean) / rgb_std

    return img.transpose((2, 0, 1)).expand_dims(axis=0).as_in_context(ctx)


def postprocess(image):# 後處理圖像
    new_image=(image.transpose((1, 2, 0)) * rgb_std + rgb_mean).clip(0, 1)
    new_image=(new_image*255).asnumpy().astype(np.uint8)
    return new_image


# 3、模型
pretrained_net = models.vgg19(pretrained=True,ctx=ctx)
print(pretrained_net)

style_layers = [0, 5, 10, 19, 28]
content_layers = [25]


# 4、構建新網絡
def get_net(pretrained_net, content_layers, style_layers):
    net = gn.nn.Sequential()
    for i in range(max(content_layers + style_layers) + 1):
        # print(i)
        net.add(pretrained_net.features[i])
    return net


net = get_net(pretrained_net, content_layers, style_layers)
net.collect_params().reset_ctx(ctx=ctx)

def extract_features(x, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        x = net[i](x)
        if i in style_layers:
            styles.append(x)
        if i in content_layers:
            contents.append(x)
    return contents, styles

# 5、損失函數
# 1)內容匹配
def content_loss(yhat, y):
    return (yhat - y).square().mean()

# 2)樣式匹配,先定義Gram矩陣
def gram(x):
    # 根據公式,獲得channel
    c, h, w = x.shape[1], x.shape[2], x.shape[3]
    hw = h * w
    y = x.reshape((c, int(hw)))
    return nd.dot(y, y.T) / hw # 除hw爲了normliza一下

def style_loss(yhat,gram_y):
    return (gram(yhat)-gram_y).square().mean()

def tv_loss(yhat):
    return 0.5 * ((yhat[:, :, 1:, :] - yhat[:, :, :-1, :]).abs().mean() +
                  (yhat[:, :, :, 1:] - yhat[:, :, :, :-1]).abs().mean())

channels=[net[l].weight.shape[0] for l in style_layers] #拿出channel數
style_weights=[1e4/n**2 for n in channels]
content_weight=[1]
tv_weight=10

def sum_loss(loss,preds,truths,weights):
    return nd.add_n(*[w*loss(yhat,y) for w,yhat,y in zip(weights,preds,truths)])

# 6、訓練
# 定義兩個函數,分別對樣式圖片和內容圖片提取特徵
def get_style(style_image,image_shape):
    style_x=image_preprocess(style_image,image_shape,)
    _,style_y=extract_features(style_x,content_layers,style_layers)
    style_y=[gram(y) for y in style_y]
    return style_x,style_y

def get_content(content_image,image_shape):
    content_x=image_preprocess(content_image,image_shape)
    content_y,_ = extract_features(content_x, content_layers, style_layers)

    return content_x,content_y

image_shape=(200,300)
content_x,content_y=get_content(content_image,image_shape)

style_x,style_y=get_style(style_image,image_shape)
x=content_x.copyto(ctx) # 初始化合成圖像x,將合成圖片的初始值設爲(樣式)內容圖片來加速收斂
x.attach_grad()

def train(x,max_epochs,lr,lr_decay_epoch=200):
    for i in range(max_epochs):
        with ag.record():
            content_py,style_py=extract_features(x,content_layers,style_layers)

            content_L=sum_loss(content_loss,content_py,content_y,content_weight)
            style_L=sum_loss(style_loss,style_py,style_y,style_weights)
            tv_L=tv_loss(x)*tv_weight
            loss=content_L+style_L+tv_L
        loss.backward()
        x.grad[:]/=x.grad.abs().mean()+1e-8 #使得它對lr沒那麼敏感
        x[:]-=lr*x.grad
        nd.waitall()
        if i and i%5==0:
            print("batch %2d,content %.2f,style %.2f,TV %.2f"%(
                i,content_L.asscalar(),style_L.asscalar(),tv_L.asscalar()
            ))
        if i and i%lr_decay_epoch==0:
            lr*=0.1
    return x


y=train(x,1000,0.1)
# 7、顯示合成圖像
y=nd.squeeze(y,0).as_in_context(mx.cpu())

result=postprocess(y) # 顯示合成的x
result=cv.cvtColor(result,cv.COLOR_RGB2BGR)
cv.imshow("fix image",result)
cv.waitKey(0)
cv.destroyAllWindows()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章