如果你喜歡拍照,可能接觸過濾鏡(如下圖)。它能改變照片的顏色樣式,從而使風景照更加銳利或者令人像更加美白。但一個濾鏡通常只能改變照片的某個方面。如果要照片達到理想中的樣式,經常需要嘗試大量不同的組合,其複雜程度不亞於模型調參。
具體的原理本章就不做介紹了,請讀者自行百度或閱讀相關論文。
基本結構
上圖中,分別使用CNN對樣式圖片和內容圖片抽取特徵,比如,通過第1、2、4層卷積抽取樣式圖片特徵,通過第3層卷積抽取內容圖片特徵,然後通過正向傳播計算樣式遷移的損失函數,通過反向傳播迭代模型參數,即不斷更新合成圖像。樣式遷移常用的損失函數由3部分組成:內容損失(content loss)使合成圖像與內容圖像在內容特徵上接近,樣式損失(style loss)令合成圖像與樣式圖像在樣式特徵上接近,而總變差損失(total variation loss)則有助於減少合成圖像中的噪點。最後,當模型訓練結束時,我們輸出樣式遷移的模型參數,即得到最終的合成圖像。一般步驟如下:
1)挑選網絡特定層分別作爲樣式層和內容層;
2)輸入樣式圖片並保存樣式層輸出,記爲;
3) 輸入內容圖片並保存內容層輸出,記爲;
4) 初始化合成圖片X爲隨機值,然後進行迭代,使得用抽取的特徵能夠匹配上和,具體來說,我們如下迭代直至收斂:
a) 輸入計算樣式層和內容層輸出,記層輸出爲;
b) 使用樣式損失函數計算和的差異;
c) 使用內容損失函數計算和的差異;
d) 對損失求和並對求導,記導數爲;
e) 更新,例如。
內容損失函數通常使用迴歸的均方誤差,對於樣式,我們通常將它看成是像素點在每個通道的統計分佈。例如要匹配兩張圖片的顏色,我們的一個做法是匹配這兩張圖片在RGB三個通道的直方圖,更一般的,假設卷積輸出的格式是,那麼我們可以把它變形成一個的2D矩陣,並將它看成是一個維度爲的隨機變量採樣到的個點。所謂的樣式匹配就是使得兩個維隨機變量統計分佈一致。
匹配統計分佈常用的做法是衝量匹配,就是說使它們有一樣的均值、協方差和其它高維的衝量。爲了計算簡單起見,我們這裏假設卷積輸出已經是均值爲0了,而且,我們只匹配協方差,也就是說,樣式損失函數就是對和計算Gram矩陣然後應用均方誤差:
這裏假設我們已經將和變形成了的2D矩陣了。
下面我們將實現這個算法來深入理解各個參數,例如樣式層和內容層的選取,對實際結果的影響。
1、數據
樣式圖片:
內容圖片:
2、數據處理
rgb_mean=nd.array([0.485,0.456,0.406])
rgb_std=nd.array([0.229,0.224,0.225])
def image_preprocess(image,input_shape):
image=cv.resize(image,*input_shape)
img=(nd.array(image).astype("float32") / 255.0 - rgb_mean) / rgb_std
return img.transpose((2,0,1)).expand_dims(axis=0)
def postprocess(image):# 後處理圖像
new_image=(image.transpose((1, 2, 0)) * rgb_std + rgb_mean).clip(0, 1)
new_image=(new_image*255).asnumpy().astype(np.uint8)
return new_image
3、模型
本章我們使用的是VGG-19的預訓練模型,它已經在ImageNet上做過預訓練了。
pretrained_net = models.vgg19(pretrained=True)
print(pretrained_net)
模型結構:
從結構上來看,VGG-19有5個卷積塊(5個池化),這裏我們按照原論文的每個卷積塊的第一個卷積輸出來匹配樣式,第四個卷積塊的最後一個卷積輸出來匹配內容。根據打印的結果我們來記錄下位置:
style_layers=[0,5,10,19,28]
content_layers=[25]
4、構建新網絡
因爲只需要用到中間層的輸出,所以需要構建新的網絡,它只保留需要用到的VGG的所有層:
def get_net(pretrained_net,content_layers,style_layers):
net=gn.nn.Sequential()
for i in range(max(content_layers+style_layers)+1):
# print(i)
net.add(pretrained_net.features[i])
return net
net=get_net(pretrained_net,content_layers,style_layers)
給定輸入x,我們能拿到最後一層的輸出,現在我們需要拿到指定層的輸出:
def extract_features(x,content_layers,style_layers):
contents=[]
styles=[]
for i in range(len(net)):
x=net[i](x)
if i in style_layers:
styles.append(x)
if i in content_layers:
contents.append(x)
return contents, styles
5、損失函數
內容匹配是一個迴歸問題,這裏使用均方誤差來比較內容層的輸出:
1) 內容匹配損失函數
def content_loss(yhat,y):
return (yhat-y).square().mean()
2)樣式匹配損失函數,通過擬合Gram矩陣
def gram(x):
# 根據公式,獲得channel
c, h, w = x.shape[1], x.shape[2], x.shape[3]
hw = h * w
y = x.reshape((c, int(hw)))
return nd.dot(y, y.T) / hw # 除hw爲了normliza一下
3)總變差降噪
當使用輸出層的高層輸出來擬合時,會發現學到的圖片中含有大量噪聲,對於解決噪聲,可以利用模糊濾鏡(中值濾波、高斯濾波等,但會使圖片邊緣信息也模糊),或者總變差降噪。
假設表示每個像素,那麼加入下面的損失函數,它使得鄰近的像素值相似:
所以,根據公式有:
def tv_loss(yhat):
return 0.5 * ((yhat[:, :, 1:, :] - yhat[:, :, :-1, :]).abs().mean() +
(yhat[:, :, :, 1:] - yhat[:, :, :, :-1]).abs().mean())
總損失函數是上述三個損失函數的加權和,通過調整權重值,我們可以控制學到的圖片是否保留更多樣式、內容。注意到樣式匹配中我們使用了5個層的輸出,所以所以應該對靠近輸入層的層給予更大的權重。
channels=[net[l].weight.shape[0] for l in style_layers] #拿出channel數
style_weights=[1e4/n**2 for n in channels]
content_weight=[1]
tv_weight=10
def sum_loss(loss,preds,truths,weights):
return nd.add_n(*[w*loss(yhat,y) for w,yhat,y in zip(weights,preds,truths)])
6、訓練
# 6、訓練
# 定義兩個函數,分別對樣式圖片和內容圖片提取特徵
def get_style(style_image,image_shape):
style_x=image_preprocess(style_image,image_shape,)
_,style_y=extract_features(style_x,content_layers,style_layers)
style_y=[gram(y) for y in style_y]
return style_x,style_y
def get_content(content_image,image_shape):
content_x=image_preprocess(content_image,image_shape)
content_y,_ = extract_features(content_x, content_layers, style_layers)
return content_x,content_y
image_shape=(200,300)
content_x,content_y=get_content(content_image,image_shape)
style_x,style_y=get_style(style_image,image_shape)
x=content_x.copyto(ctx) # 初始化合成圖像x,將合成圖片的初始值設爲(樣式)內容圖片來加速收斂
x.attach_grad()
def train(x,max_epochs,lr,lr_decay_epoch=200):
for i in range(max_epochs):
with ag.record():
content_py,style_py=extract_features(x,content_layers,style_layers)
content_L=sum_loss(content_loss,content_py,content_y,content_weight)
style_L=sum_loss(style_loss,style_py,style_y,style_weights)
tv_L=tv_loss(x)*tv_weight
loss=content_L+style_L+tv_L
loss.backward()
x.grad[:]/=x.grad.abs().mean()+1e-8 #使得它對lr沒那麼敏感
x[:]-=lr*x.grad
nd.waitall()
if i and i%5==0:
print("batch %2d,content %.2f,style %.2f,TV %.2f"%(
i,content_L.asscalar(),style_L.asscalar(),tv_L.asscalar()
))
if i and i%lr_decay_epoch==0:
lr*=0.1
y=train(x,1000,0.1)
訓練結果:
7、顯示合成的圖片
y=nd.squeeze(y,0).as_in_context(mx.cpu())
result=postprocess(y) # 顯示合成的x
result=cv.cvtColor(result,cv.COLOR_RGB2BGR)
cv.imshow("fix image",result)
cv.waitKey(0)
cv.destroyAllWindows()
讓我們來看看最終結果:
樣式圖片:
內容圖片:
合成的圖片:
是不是有那麼點味道了?要想讓它更能學習到樣式和內容圖片的特徵,並且加速收斂,可以把初始化的x(合成圖片)設爲上面生成的合成圖片,讓上面生成的合成圖片去做權值更新。
下面放上所有代碼:
import mxnet.ndarray as nd
import mxnet as mx
import mxnet.gluon as gn
import mxnet.autograd as ag
from mxnet.gluon.model_zoo import vision as models
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
ctx=mx.gpu()
# 1、數據
style_image=cv.imread("style_content/style.jpeg")
content_image=cv.imread("style_content/content.jpg")
# 2、數據預處理
rgb_mean = nd.array([0.485, 0.456, 0.406])
rgb_std = nd.array([0.229, 0.224, 0.225])
def image_preprocess(image, input_shape):
image = cv.resize(image, (input_shape[1],input_shape[0]))
image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
img = (nd.array(image).astype("float32") / 255.0 - rgb_mean) / rgb_std
return img.transpose((2, 0, 1)).expand_dims(axis=0).as_in_context(ctx)
def postprocess(image):# 後處理圖像
new_image=(image.transpose((1, 2, 0)) * rgb_std + rgb_mean).clip(0, 1)
new_image=(new_image*255).asnumpy().astype(np.uint8)
return new_image
# 3、模型
pretrained_net = models.vgg19(pretrained=True,ctx=ctx)
print(pretrained_net)
style_layers = [0, 5, 10, 19, 28]
content_layers = [25]
# 4、構建新網絡
def get_net(pretrained_net, content_layers, style_layers):
net = gn.nn.Sequential()
for i in range(max(content_layers + style_layers) + 1):
# print(i)
net.add(pretrained_net.features[i])
return net
net = get_net(pretrained_net, content_layers, style_layers)
net.collect_params().reset_ctx(ctx=ctx)
def extract_features(x, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
x = net[i](x)
if i in style_layers:
styles.append(x)
if i in content_layers:
contents.append(x)
return contents, styles
# 5、損失函數
# 1)內容匹配
def content_loss(yhat, y):
return (yhat - y).square().mean()
# 2)樣式匹配,先定義Gram矩陣
def gram(x):
# 根據公式,獲得channel
c, h, w = x.shape[1], x.shape[2], x.shape[3]
hw = h * w
y = x.reshape((c, int(hw)))
return nd.dot(y, y.T) / hw # 除hw爲了normliza一下
def style_loss(yhat,gram_y):
return (gram(yhat)-gram_y).square().mean()
def tv_loss(yhat):
return 0.5 * ((yhat[:, :, 1:, :] - yhat[:, :, :-1, :]).abs().mean() +
(yhat[:, :, :, 1:] - yhat[:, :, :, :-1]).abs().mean())
channels=[net[l].weight.shape[0] for l in style_layers] #拿出channel數
style_weights=[1e4/n**2 for n in channels]
content_weight=[1]
tv_weight=10
def sum_loss(loss,preds,truths,weights):
return nd.add_n(*[w*loss(yhat,y) for w,yhat,y in zip(weights,preds,truths)])
# 6、訓練
# 定義兩個函數,分別對樣式圖片和內容圖片提取特徵
def get_style(style_image,image_shape):
style_x=image_preprocess(style_image,image_shape,)
_,style_y=extract_features(style_x,content_layers,style_layers)
style_y=[gram(y) for y in style_y]
return style_x,style_y
def get_content(content_image,image_shape):
content_x=image_preprocess(content_image,image_shape)
content_y,_ = extract_features(content_x, content_layers, style_layers)
return content_x,content_y
image_shape=(200,300)
content_x,content_y=get_content(content_image,image_shape)
style_x,style_y=get_style(style_image,image_shape)
x=content_x.copyto(ctx) # 初始化合成圖像x,將合成圖片的初始值設爲(樣式)內容圖片來加速收斂
x.attach_grad()
def train(x,max_epochs,lr,lr_decay_epoch=200):
for i in range(max_epochs):
with ag.record():
content_py,style_py=extract_features(x,content_layers,style_layers)
content_L=sum_loss(content_loss,content_py,content_y,content_weight)
style_L=sum_loss(style_loss,style_py,style_y,style_weights)
tv_L=tv_loss(x)*tv_weight
loss=content_L+style_L+tv_L
loss.backward()
x.grad[:]/=x.grad.abs().mean()+1e-8 #使得它對lr沒那麼敏感
x[:]-=lr*x.grad
nd.waitall()
if i and i%5==0:
print("batch %2d,content %.2f,style %.2f,TV %.2f"%(
i,content_L.asscalar(),style_L.asscalar(),tv_L.asscalar()
))
if i and i%lr_decay_epoch==0:
lr*=0.1
return x
y=train(x,1000,0.1)
# 7、顯示合成圖像
y=nd.squeeze(y,0).as_in_context(mx.cpu())
result=postprocess(y) # 顯示合成的x
result=cv.cvtColor(result,cv.COLOR_RGB2BGR)
cv.imshow("fix image",result)
cv.waitKey(0)
cv.destroyAllWindows()