MXNET深度学习框架-30-样式迁移

        如果你喜欢拍照,可能接触过滤镜(如下图)。它能改变照片的颜色样式,从而使风景照更加锐利或者令人像更加美白。但一个滤镜通常只能改变照片的某个方面。如果要照片达到理想中的样式,经常需要尝试大量不同的组合,其复杂程度不亚于模型调参。

在这里插入图片描述
具体的原理本章就不做介绍了,请读者自行百度或阅读相关论文。

基本结构
在这里插入图片描述
        上图中,分别使用CNN对样式图片和内容图片抽取特征,比如,通过第1、2、4层卷积抽取样式图片特征,通过第3层卷积抽取内容图片特征,然后通过正向传播计算样式迁移的损失函数,通过反向传播迭代模型参数,即不断更新合成图像。样式迁移常用的损失函数由3部分组成:内容损失(content loss)使合成图像与内容图像在内容特征上接近,样式损失(style loss)令合成图像与样式图像在样式特征上接近,而总变差损失(total variation loss)则有助于减少合成图像中的噪点。最后,当模型训练结束时,我们输出样式迁移的模型参数,即得到最终的合成图像。一般步骤如下:
        1)挑选网络特定层分别作为样式层和内容层;
        2)输入样式图片并保存样式层输出,记为sis_i;
        3) 输入内容图片并保存内容层输出,记为cic_i
        4) 初始化合成图片X为随机值,然后进行迭代,使得用xx抽取的特征能够匹配上sis_icic_i,具体来说,我们如下迭代直至收敛:
                a) 输入xx计算样式层和内容层输出,记ii层输出为yiy_i
                b) 使用样式损失函数计算yiy_isis_i的差异;
                c) 使用内容损失函数计算yiy_icic_i的差异;
                d) 对损失求和并对xx求导,记导数为gg
                e) 更新xx,例如xηgx-\eta g

        内容损失函数通常使用回归的均方误差,对于样式,我们通常将它看成是像素点在每个通道的统计分布。例如要匹配两张图片的颜色,我们的一个做法是匹配这两张图片在RGB三个通道的直方图,更一般的,假设卷积输出的格式是c×h×wc×h×w,那么我们可以把它变形成一个c×hwc×hw的2D矩阵,并将它看成是一个维度为cc的随机变量采样到的hwhw个点。所谓的样式匹配就是使得两个cc维随机变量统计分布一致。

        匹配统计分布常用的做法是冲量匹配,就是说使它们有一样的均值、协方差和其它高维的冲量。为了计算简单起见,我们这里假设卷积输出已经是均值为0了,而且,我们只匹配协方差,也就是说,样式损失函数就是对sis_iyiy_i计算Gram矩阵然后应用均方误差:
styleloss(si,yi)=1c2hwsisiTyiyiTF2styleloss(s_i,y_i)=\frac{1}{c^2hw}\quad||s_is_i^T-y_iy_i^T||_F^2

        这里假设我们已经将sis_iyiy_i变形成了c×hwc×hw的2D矩阵了。

        下面我们将实现这个算法来深入理解各个参数,例如样式层和内容层的选取,对实际结果的影响。

1、数据

样式图片:
在这里插入图片描述
内容图片:
在这里插入图片描述

2、数据处理

rgb_mean=nd.array([0.485,0.456,0.406])
rgb_std=nd.array([0.229,0.224,0.225])
def image_preprocess(image,input_shape):
    image=cv.resize(image,*input_shape)
    img=(nd.array(image).astype("float32") / 255.0 - rgb_mean) / rgb_std
    return img.transpose((2,0,1)).expand_dims(axis=0)

def postprocess(image):# 后处理图像
    new_image=(image.transpose((1, 2, 0)) * rgb_std + rgb_mean).clip(0, 1)
    new_image=(new_image*255).asnumpy().astype(np.uint8)
    return new_image

3、模型

        本章我们使用的是VGG-19的预训练模型,它已经在ImageNet上做过预训练了。

pretrained_net = models.vgg19(pretrained=True)
print(pretrained_net)

模型结构:
在这里插入图片描述在这里插入图片描述在这里插入图片描述
        从结构上来看,VGG-19有5个卷积块(5个池化),这里我们按照原论文的每个卷积块的第一个卷积输出来匹配样式,第四个卷积块的最后一个卷积输出来匹配内容。根据打印的结果我们来记录下位置:

style_layers=[0,5,10,19,28]
content_layers=[25]

4、构建新网络

        因为只需要用到中间层的输出,所以需要构建新的网络,它只保留需要用到的VGG的所有层:

def get_net(pretrained_net,content_layers,style_layers):
    net=gn.nn.Sequential()
    for i in range(max(content_layers+style_layers)+1):
        # print(i)
        net.add(pretrained_net.features[i])
    return net
net=get_net(pretrained_net,content_layers,style_layers)

给定输入x,我们能拿到最后一层的输出,现在我们需要拿到指定层的输出:

def extract_features(x,content_layers,style_layers):
    contents=[]
    styles=[]
    for i in range(len(net)):
        x=net[i](x)
        if i in style_layers:
            styles.append(x)
        if i in content_layers:
            contents.append(x)
    return contents, styles

5、损失函数

        内容匹配是一个回归问题,这里使用均方误差来比较内容层的输出:
1) 内容匹配损失函数

def content_loss(yhat,y):
    return (yhat-y).square().mean()

2)样式匹配损失函数,通过拟合Gram矩阵

def gram(x):
    # 根据公式,获得channel
    c, h, w = x.shape[1], x.shape[2], x.shape[3]
    hw = h * w
    y = x.reshape((c, int(hw)))
    return nd.dot(y, y.T) / hw # 除hw为了normliza一下

3)总变差降噪
        当使用输出层的高层输出来拟合时,会发现学到的图片中含有大量噪声,对于解决噪声,可以利用模糊滤镜(中值滤波、高斯滤波等,但会使图片边缘信息也模糊),或者总变差降噪。

        假设xi,jx_{i,j}表示每个像素,那么加入下面的损失函数,它使得邻近的像素值相似:
i,jxi,jxi+1,j+xi,jxi,j+1 \sum_{i,j}^\infty |x_{i,j}-x_{i+1,j}|+|x_{i,j}-x_{i,j+1}|
所以,根据公式有:

def tv_loss(yhat):
    return 0.5 * ((yhat[:, :, 1:, :] - yhat[:, :, :-1, :]).abs().mean() +
                  (yhat[:, :, :, 1:] - yhat[:, :, :, :-1]).abs().mean())

        总损失函数是上述三个损失函数的加权和,通过调整权重值,我们可以控制学到的图片是否保留更多样式、内容。注意到样式匹配中我们使用了5个层的输出,所以所以应该对靠近输入层的层给予更大的权重。

channels=[net[l].weight.shape[0] for l in style_layers] #拿出channel数
style_weights=[1e4/n**2 for n in channels]
content_weight=[1]
tv_weight=10

def sum_loss(loss,preds,truths,weights):
    return nd.add_n(*[w*loss(yhat,y) for w,yhat,y in zip(weights,preds,truths)])

6、训练

# 6、训练
# 定义两个函数,分别对样式图片和内容图片提取特征
def get_style(style_image,image_shape):
    style_x=image_preprocess(style_image,image_shape,)
    _,style_y=extract_features(style_x,content_layers,style_layers)
    style_y=[gram(y) for y in style_y]
    return style_x,style_y

def get_content(content_image,image_shape):
    content_x=image_preprocess(content_image,image_shape)
    content_y,_ = extract_features(content_x, content_layers, style_layers)

    return content_x,content_y

image_shape=(200,300)
content_x,content_y=get_content(content_image,image_shape)

style_x,style_y=get_style(style_image,image_shape)
x=content_x.copyto(ctx) # 初始化合成图像x,将合成图片的初始值设为(样式)内容图片来加速收敛
x.attach_grad()

def train(x,max_epochs,lr,lr_decay_epoch=200):
    for i in range(max_epochs):
        with ag.record():
            content_py,style_py=extract_features(x,content_layers,style_layers)

            content_L=sum_loss(content_loss,content_py,content_y,content_weight)
            style_L=sum_loss(style_loss,style_py,style_y,style_weights)
            tv_L=tv_loss(x)*tv_weight
            loss=content_L+style_L+tv_L
        loss.backward()
        x.grad[:]/=x.grad.abs().mean()+1e-8 #使得它对lr没那么敏感
        x[:]-=lr*x.grad
        nd.waitall()
        if i and i%5==0:
            print("batch %2d,content %.2f,style %.2f,TV %.2f"%(
                i,content_L.asscalar(),style_L.asscalar(),tv_L.asscalar()
            ))
        if i and i%lr_decay_epoch==0:
            lr*=0.1

y=train(x,1000,0.1)

训练结果:
在这里插入图片描述

7、显示合成的图片

y=nd.squeeze(y,0).as_in_context(mx.cpu())

result=postprocess(y) # 显示合成的x
result=cv.cvtColor(result,cv.COLOR_RGB2BGR)
cv.imshow("fix image",result)
cv.waitKey(0)
cv.destroyAllWindows()

让我们来看看最终结果:
样式图片:
在这里插入图片描述
内容图片:
在这里插入图片描述
合成的图片:
在这里插入图片描述
        是不是有那么点味道了?要想让它更能学习到样式和内容图片的特征,并且加速收敛,可以把初始化的x(合成图片)设为上面生成的合成图片,让上面生成的合成图片去做权值更新。

下面放上所有代码:

import mxnet.ndarray as nd
import mxnet as mx
import mxnet.gluon as gn
import mxnet.autograd as ag
from mxnet.gluon.model_zoo import vision as models
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
ctx=mx.gpu()
# 1、数据
style_image=cv.imread("style_content/style.jpeg")
content_image=cv.imread("style_content/content.jpg")

# 2、数据预处理
rgb_mean = nd.array([0.485, 0.456, 0.406])
rgb_std = nd.array([0.229, 0.224, 0.225])


def image_preprocess(image, input_shape):
    image = cv.resize(image, (input_shape[1],input_shape[0]))
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    img = (nd.array(image).astype("float32") / 255.0 - rgb_mean) / rgb_std

    return img.transpose((2, 0, 1)).expand_dims(axis=0).as_in_context(ctx)


def postprocess(image):# 后处理图像
    new_image=(image.transpose((1, 2, 0)) * rgb_std + rgb_mean).clip(0, 1)
    new_image=(new_image*255).asnumpy().astype(np.uint8)
    return new_image


# 3、模型
pretrained_net = models.vgg19(pretrained=True,ctx=ctx)
print(pretrained_net)

style_layers = [0, 5, 10, 19, 28]
content_layers = [25]


# 4、构建新网络
def get_net(pretrained_net, content_layers, style_layers):
    net = gn.nn.Sequential()
    for i in range(max(content_layers + style_layers) + 1):
        # print(i)
        net.add(pretrained_net.features[i])
    return net


net = get_net(pretrained_net, content_layers, style_layers)
net.collect_params().reset_ctx(ctx=ctx)

def extract_features(x, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        x = net[i](x)
        if i in style_layers:
            styles.append(x)
        if i in content_layers:
            contents.append(x)
    return contents, styles

# 5、损失函数
# 1)内容匹配
def content_loss(yhat, y):
    return (yhat - y).square().mean()

# 2)样式匹配,先定义Gram矩阵
def gram(x):
    # 根据公式,获得channel
    c, h, w = x.shape[1], x.shape[2], x.shape[3]
    hw = h * w
    y = x.reshape((c, int(hw)))
    return nd.dot(y, y.T) / hw # 除hw为了normliza一下

def style_loss(yhat,gram_y):
    return (gram(yhat)-gram_y).square().mean()

def tv_loss(yhat):
    return 0.5 * ((yhat[:, :, 1:, :] - yhat[:, :, :-1, :]).abs().mean() +
                  (yhat[:, :, :, 1:] - yhat[:, :, :, :-1]).abs().mean())

channels=[net[l].weight.shape[0] for l in style_layers] #拿出channel数
style_weights=[1e4/n**2 for n in channels]
content_weight=[1]
tv_weight=10

def sum_loss(loss,preds,truths,weights):
    return nd.add_n(*[w*loss(yhat,y) for w,yhat,y in zip(weights,preds,truths)])

# 6、训练
# 定义两个函数,分别对样式图片和内容图片提取特征
def get_style(style_image,image_shape):
    style_x=image_preprocess(style_image,image_shape,)
    _,style_y=extract_features(style_x,content_layers,style_layers)
    style_y=[gram(y) for y in style_y]
    return style_x,style_y

def get_content(content_image,image_shape):
    content_x=image_preprocess(content_image,image_shape)
    content_y,_ = extract_features(content_x, content_layers, style_layers)

    return content_x,content_y

image_shape=(200,300)
content_x,content_y=get_content(content_image,image_shape)

style_x,style_y=get_style(style_image,image_shape)
x=content_x.copyto(ctx) # 初始化合成图像x,将合成图片的初始值设为(样式)内容图片来加速收敛
x.attach_grad()

def train(x,max_epochs,lr,lr_decay_epoch=200):
    for i in range(max_epochs):
        with ag.record():
            content_py,style_py=extract_features(x,content_layers,style_layers)

            content_L=sum_loss(content_loss,content_py,content_y,content_weight)
            style_L=sum_loss(style_loss,style_py,style_y,style_weights)
            tv_L=tv_loss(x)*tv_weight
            loss=content_L+style_L+tv_L
        loss.backward()
        x.grad[:]/=x.grad.abs().mean()+1e-8 #使得它对lr没那么敏感
        x[:]-=lr*x.grad
        nd.waitall()
        if i and i%5==0:
            print("batch %2d,content %.2f,style %.2f,TV %.2f"%(
                i,content_L.asscalar(),style_L.asscalar(),tv_L.asscalar()
            ))
        if i and i%lr_decay_epoch==0:
            lr*=0.1
    return x


y=train(x,1000,0.1)
# 7、显示合成图像
y=nd.squeeze(y,0).as_in_context(mx.cpu())

result=postprocess(y) # 显示合成的x
result=cv.cvtColor(result,cv.COLOR_RGB2BGR)
cv.imshow("fix image",result)
cv.waitKey(0)
cv.destroyAllWindows()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章