PyTorch ------圖像風格遷移學習原理

圖像風格遷移學習原理

圖像風格遷移學習介紹

  • 利用算法將一張圖片的風格樣式,應用到另一張圖畫上的技術亦可以稱爲Neural-Style或者Neural-Transfer
  • 該算法獲取三張圖片,即輸入圖片、內容圖片和樣式圖片,然後更改輸入以使其類似於內容圖像的內容和樣式圖像的風格
  • 在這裏插入圖片描述

基本原理

  • 原理很簡單我們定義兩個表示距離的變量,一個表示輸入圖片和內容圖片的距離(Dc),一個表示輸入圖片和樣式圖片的距離(Ds).即Dc測量輸入和內容圖片的內容差異的距離,Ds則測量輸入和樣式圖片之間樣式的差異距離.
  • 最後我們將優化Dc和Ds使之最小,即完成圖像風格轉移

相關知識

Gram matrix

  • Gram矩陣和協方差矩陣相似,差異在於Gram矩陣沒有白化,直接使用兩變量做內積
  • Gram矩陣和相關係數矩陣葉相似,差異在於,沒有白化,也沒有標準化
  • 總結上面說來就是Gram 矩陣相對於協方差矩陣和相關關係矩陣來說比較粗糙簡單,但亦能表達其意思.
  • 不瞭解協方差和相關關係的同學可以參考傳送門,知乎贊最多的一篇文章👍👍👍👍👍👍👍👍👍
  • 在這裏插入圖片描述

VGG

  • 提取圖像風格和圖像內容的圖像是VGG19神經網絡模型
  • 這個可以參考上一片文章傳送門👏👏👏👏👏👏👏👏👏
  • 對於VGG模型一般來說,越靠近輸入層的卷積層輸出越容易抽取圖像的細節信息例如淺層的conv1_1,conv1_2,提取的特徵通常是比較簡單的線,角,靠近輸出的卷積層輸出的是全局的信息,特徵比較複雜也可以認爲是整體的信息
  • 下圖爲VGG19的特徵提取的結構
  • 在這裏插入圖片描述

下面上代碼時間

import time import torch import torch.nn.functional as F import torchvision import numpy as np import matplotlib.pyplot as plt from PIL import Image import cv2 as cv import os import sys import platform #檢測設備 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device) print(torch.__version__) Base_file_path = "./dataset/" content_file_pathpng = "rainier.png" style_file_pathpng = "autumn_oak.png" content_file_path = "rainier.jpg" style_file_path = "autumn_oak.jpg" content_image = Image.open(os.path.join(Base_file_path,content_file_path)) plt.imshow(content_image) plt.show() style_image = Image.open(os.path.join(Base_file_path,style_file_path)) plt.imshow(style_image) rgb_mean = np.array([0.485,0.456,0.406]) rgb_std = np.array([0.229,0.224,0.225]) def preprocess(PIL_image,image_shape): process = torchvision.transforms.Compose([ torchvision.transforms.Resize(size = image_shape), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=rgb_mean,std=rgb_std) ]) return process(PIL_image).unsqueeze(dim=0) def postprocess(img_tensor): inv_normalize = torchvision.transforms.Normalize( mean=-rgb_mean/rgb_std, std=1/rgb_std ) to_PIL_image = torchvision.transforms.ToPILImage() return to_PIL_image(inv_normalize(img_tensor[0].cpu()).clamp(0,1)) #VGG19 pretrained_net = torchvision.models.vgg19(pretrained=True) print(pretrained_net) """ 爲了抽取圖像的內容特徵和樣式特徵,我們可以選擇VGG網絡中某些層的輸出。 一般來說,越靠近輸入層的輸出越容易抽取圖像的細節信息,反之則越容易抽取圖像的全局信息。 爲了避免合成圖像過多的保留內容圖像的細節,我們選擇VGG較靠近輸出的層,也稱爲內容層,來輸出圖像的內容特徵。 我們還從VGG中選擇不同層的輸出來匹配局部和全局的樣式,這些層也叫樣式層。 """ #style layers 每個Block的第一個卷積層 """ 指定的特徵層 可以優化 或許其他層的 提取樣式 內容 會更好 也可以 更換模型 來對比 提取的樣式內容 """ style_layers,content_layers = [0,5,10,19,28],[25] #提取特徵 net_list = [] for i in range(max(content_layers + style_layers) + 1): # 將 預訓練 模型的 指定的層的特徵 提取出來 net_list.append(pretrained_net.features[i]) #重新組成一個模型 net = torch.nn.Sequential(*net_list) """ 給定輸入X,如果簡單調用前向計算net(X),只能獲得最後一層的輸出。 由於我們還需要中間層的輸出,因此這裏我們逐層計算,並保留內容層和樣式層的輸出。 """ def extract_features(X,content_layers,style_layers): contents = [] styles = [] for i in range(len(net)): X = net[i](X) if i in style_layers: styles.append(X) if i in content_layers: contents.append(X) return contents,styles #將內容圖片 的內容特徵 提取出來 def get_contents(image_shape,device): #將內容圖片 經過處理後的矩陣 content_X = preprocess(content_image,image_shape).to(device) #將內容 圖片的 內容特徵 提取出來 contents_Y ,_ = extract_features(content_X,content_layers,style_layers) return content_X,contents_Y #將樣式圖片 的樣式特徵 提取出來 def get_styles(image_shape,device): style_X = preprocess(style_image,image_shape).to(device) # 將樣式圖片 的樣式特徵 提取出來 _,styles_Y = extract_features(style_X,content_layers,style_layers) return style_X,styles_Y tmp_content, contents_Y = get_contents(1360, device) result = postprocess(tmp_content) plt.imshow(result) plt.show() #定義損失函數 """ 樣式遷移的損失函數,它由內容損失,樣式損失和總變差損失3部分組成 """ """ 內容損失 與線性迴歸中的損失函數類似,內容損失通過平方差誤差函數衡量合成圖像與內容特徵上的差異。 平方誤差函數的兩個輸入均爲 extract——features 函數計算所得到的內容層的輸出 """ #內容損失函數 比較的是 合成圖像內容 和 提供 內容的圖像的loss function def content_loss(Y_hat,Y): return F.mse_loss(Y_hat,Y) """ 樣式損失 樣式損失也一樣通過平方差誤差函數衡量合成圖像與樣式圖像在樣式上的差異,爲了表達樣式層輸出的樣式, 我們先通過extract-features函數計算樣式層的輸出。假設該輸出的樣本數爲1,通道數爲C 高和寬分別爲h和w, 我們可以把輸出變換成c行hw列的矩陣X。矩陣X可以看作是由C個長度爲hw的向量X1,。。。Xc """ """ gram metrax 體現的是圖片中 各個特徵通道的相關性 """ def gram(X): num_channels,n = X.shape[1],X.shape[2] * X.shape[3] X = X.view(num_channels,n) return torch.matmul(X,X.t())/(num_channels*n) #樣式的損失函數 比較的 是 合成圖像的樣式 和樣式圖片提供的樣式的 loss function def style_loss(Y_hat,gram_Y): return F.mse_loss(gram(Y_hat),gram_Y) #總變差 損失 """ 合成圖像裏面有大量的高頻噪點,即有特別亮或者特別暗的顆粒像素 。 一種常用的降噪方法是總變差降噪 降低總變差損失 """ def tv_loss(Y_hat): return 0.5 *(F.l1_loss(Y_hat[:,:,1:,:],Y_hat[:,:,:-1,:]) + F.l1_loss(Y_hat[:,:,:,1:],Y_hat[:,:,:,:-1])) content_weight,style_weight,tv_weight = 1,1e3,10 """ 樣式遷移的損失函數 即 內容損失、樣式損失和總變差損失函數的加權和 通過調節這些權值超參數我們可以權衡合成圖像在保留內容、遷移樣式以及降噪三方面的相對重要性 """ def compute_loss(X,contents_Y_hat,styles_Y_hat,contents_Y,styles_Y_gram): # 計算內容損失 contents_l = [content_loss(Y_hat,Y) * content_weight for Y_hat,Y in zip(contents_Y_hat,contents_Y)] #計算 樣式損失 styles_l = [style_loss(Y_hat,Y) * style_weight for Y_hat,Y in zip(styles_Y_hat,styles_Y_gram)] # 計算 總變差損失 tv_l = tv_loss(X) * tv_weight l = sum(styles_l) + sum(contents_l) + tv_l return contents_l,styles_l,tv_l,l class GeneratedImage(torch.nn.Module): def __init__(self,image_shape): super(GeneratedImage,self).__init__() self.weight = torch.nn.Parameter(torch.rand(* image_shape)) def forward(self): print("into here forward Generate Image") return self.weight def get_inits(X,device,lr,styles_Y): gen_image = GeneratedImage(X.shape).to(device) gen_image.weight.data = X.data optimizer = torch.optim.Adam(gen_image.parameters(),lr = lr) styles_Y_gram = [gram(Y) for Y in styles_Y] return gen_image(),styles_Y_gram,optimizer def train(X, contents_Y, styles_Y, device, lr, max_epochs, lr_decay_epoch): print("training on ", device) X, styles_Y_gram, optimizer = get_inits(X, device, lr, styles_Y) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_decay_epoch, gamma=0.1) for i in range(max_epochs): start = time.time() print( " epoch ",i) XCopy = np.array(X.data) print("equal metrix",(XCopy == np.array(X.data)).all()) contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers) contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram) print("zero grad....") optimizer.zero_grad() l.backward(retain_graph=True) print("backward....") optimizer.step() print("update ....") print("equal metrix", (XCopy == np.array(X.data)).all()) scheduler.step() if i % 50 == 0 and i != 0: print('epoch %3d, content loss %.2f, style loss %.2f, ' 'TV loss %.2f, %.2f sec' % (i, sum(contents_l).item(), sum(styles_l).item(), tv_l.item(), time.time() - start)) return X.detach()

#創建合成 圖片的尺寸
image_shape = (150, 224)
#將模型 轉化爲 當前設備的 數據類型
net = net.to(device)
#contentX 合成圖片的內容 載體 將 內容圖片大小 內容設置爲 合成圖片尺寸
#contentY 是將 內容圖片的 特徵 提取出來
content_X, contents_Y = get_contents(image_shape, device)
#style_X 爲合成圖片的 樣式載體 將 樣式圖片大小 內容 設置爲 合成圖片大小
#content_Y 爲將樣式圖片的 特徵 提取出來
style_X, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.01, 50, 20)
print(“well done”)

解答階段

  • 最後解答一下同學的問題:
  • 有同學問:最後輸出圖片尺寸大於內容圖片尺寸,那最後輸出是不是不準確
  • 答:在開始的時候,對圖片做了Resize根據設置的大小對圖片大小做處理,大小設置太大,圖片可能會失真
  • 同學問:gram matrix 爲什麼要歸一化
  • 答歸一化的原因 ATA內積產生的數值過大這些較大的值將導致第一層在梯度下降期間具有較大的影響,可以使模型更深,所以歸一化至關重要
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章