PyTorch進行神經風格轉換/遷移(Neural-Transfer:圖像風格遷移)

前言

在這裏插入圖片描述

1.介紹

本教程主要講解如何實現由Leon A. Gatys,Alexander S. Ecker和Matthias Bethge提出的 Neural-Style 算法。Neural-Style或者叫Neural-Transfer,可以讓你運用新的風格將你指定的圖片進行重構。這個算法將使用兩張圖片,一張圖片作爲風格提供者,一張圖片作爲內容提供者,另外生成一張圖片內容與內容圖片相似,但圖片風格的和風格圖片相似的新圖片。

2. 基本原理

其實現原理非常特別,展現了人類思維的巧妙性,
我們定義兩個優化指標:
1,一個用於內容D_C;
2,一個用於風格D_S。
D_C度量兩張圖片內容上的區別,而D_S用來測量兩張圖片風格的區別。
然後,我們生成第三張圖片,並優化這張圖片,使其與內容圖片的內容差別和風格圖片的風格差別最小化。
現在,原理講完了,開始實現吧,首先,我們導入必要的包。

3 準備工作

首先是導入以下的包:

  1. torch, torch.nn(使用PyTorch進行風格轉換必不可少的包)
  2. numpy (矩陣處理必須用)
  3. torch.optim (高效的梯度下降)
  4. PIL, PIL.Image, matplotlib.pyplot (加載和展示圖片)
  5. torchvision.transforms (將PIL圖片轉換成張量)
  6. torchvision.models (訓練或加載預訓練模型)
  7. copy (對模型進行深度拷貝;系統包)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy
import warnings
warnings.filterwarnings("ignore")

接着是一個比較關鍵的步驟,確定是GPU還是CPU來運行神經網絡。雖然,在GPU上運行可以加速,但有的電腦上沒有GPU。
我們可以使用torch.cuda.is_available()來判斷是否有可用的GPU。

if torch,cuda.is_available():
	device=torch.device("cuda")
else:
	device=torch.device("cpu")

4 加載素材

接下來是導入提供風格和內容的圖片。
導入圖片也意味着對圖片進行預處理,原始的PIL圖片的屬性值介於0到255之間,但是當轉換成torch張量時,它們的值被壓縮到0到1之間,另外,圖片分辨率也會被調整到520。一個重要的細節是,注意torch庫中的神經網絡用來訓練的張量的值爲0到1之間。如果你嘗試將0到255的張量圖片加載到神經網絡,然後激活的特徵映射將不能偵測到目標內容和風格。然而,Caffe庫中的預訓練網絡用來訓練的張量值爲0到255之間的圖片。

這是一個下載本教程需要用到的圖片的鏈接: picasso.jpgdancing.jpg。下載這兩張圖片並且將它們添加到你當前工作目錄。
如果嫌麻煩,可以關注公衆號,
在這裏插入圖片描述

發送neuralstyle,自動推送資源集成包。
設置圖片預處理程序

# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128  # use small size if no gpu
 
loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor
 
def image_loader(image_name):
    image = Image.open(image_name)
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0) #添加一個0維度 batch 適應網絡輸入
    return image.to(device, torch.float)
 
style_img = image_loader("picasso.jpg")
content_img = image_loader("dancing.jpg")
 
assert style_img.size() == content_img.size(), \
    "we need to import style and content images of the same size"

因爲tensor是四維的不能直接展示,所以,我們創建一個imshow函數,重新將圖片轉換成標準三維數據來展示,這樣也可以讓我們確認圖像是否被正確加載。

unloader = transforms.ToPILImage()  # reconvert into PIL image
 
plt.ion()
 
def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)      # remove the fake batch dimension 去掉0維度
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated
 
plt.figure()
imshow(style_img, title='Style Image')
 
plt.figure()
imshow(content_img, title='Content Image')

正確加載的話,運行到這裏,可以看到這個。
在這裏插入圖片描述
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章