風格遷移--生成你想要的風格

風格遷移–生成你想要的風格

標籤: pytorch


隨着深度網絡的流行,用AI作畫也不再是問題,比如下面這一張:

output.png

你能看出來是手畫的,還是自動生成的嗎。

下面介紹一個風格遷移網絡,能夠幫你生成任意你想要的style。本文也會提供一個Starry_Night_Over_the_Rhone的style模型,大家可以自己後臺回覆style_transform獲取代碼和模型。

下面簡單介紹一下風格遷移網絡。

網絡結構

網絡結構.jpg

上圖就是快速風格遷移網絡的結構,左邊虛線框裏面是一個Encoder-Decoder結構,而右邊整個就是一個訓練好的vgg,主要用來做特徵提取進而能夠計算圖片間的損失。

從圖中可以看出輸入是一個x,經過Image Transform Net會變爲一個y^,而這個y^就是我們要的圖片,也就是經過風格轉換後的圖片。比如我們輸入一張東方明珠電視塔圖片作爲x,那麼文章剛開始的那個圖片就是作爲y^,那這個y^是如何學習得到的呢。主要靠後面vgg網絡做損失,然後指導前面的Image Transform Net學習。

下面介紹一下這個網絡中最重要的損失函數,這個損失函數不同於之前的分類網絡的損失,原來的分類網絡一般就是一個交叉熵函數,但是這裏的損失是一個預訓練好的vgg,從圖中可以看出,Loss Network有三個輸入,分別是ys y^ yc,其中ys就是風格圖片,在本次實驗中我們選擇的是:

Starry_Night_Over_the_Rhone.jpg

正是由於ys的緣故,所以我們的y^在風格上和它非常像。而yc其實就是x。將這三個輸入到vgg裏面,然後計算利用vgg強大的特徵提取能力,把提取的特徵做爲損失,我們的目的是使得我們y^在內容上和yc相近,而風格上和ys更近,所以引出了兩類損失,第一類是風格損失,第二類是內容損失,對應圖中內容損失就直接對y^ yc的中間特徵用mse計算即可,也就是右邊下面的那個損失,而風格損失是上面的三個,是對ys y^的中間特徵計算gram得到。

最後我們優化這兩個損失就能保證我們的輸出y^在風格和ys更近,而內容上和yc更近。

代碼簡析

這部分對代碼做一個簡單的分析,其中main.py是主函數,裏面包含了兩個主要方法train、stylize,其中train是用來訓練模型的, 如果你有充分的數據集,你可以自己加載數據來進行訓練,只需要修改Config裏面data_root即可。

train裏面主要就是加載數據,加載模型TransformerNet,而TransformerNet就是前面說的那個Image Transform Net,損失網絡同樣使用的是Vgg,在訓練的過程中只更新TransformerNet的參數,因爲Vgg是作爲一個損失函數來用的,它直接使用一個ImageNet的預訓練參數即可。

而stylize函數則提供了一個測試,當我們訓練好了模型,就可以用這個函數來幫我們生成圖片了,我們在Config裏面指定一個content_path,這裏我們可以假定是一個東方明珠,你可以用其他代替。在stylize裏面做的事情就是把TransformerNet加載一下,注意要把訓練好的模型給加載上去,然後一次前向傳播即可。

如果你想直接用

如果你不太懂,而想嘗下鮮,那麼下載完代碼後,直接輸入python main.py即可,如果想換個圖片,直接修改第40行的content_path = 'images.jpeg' # 需要進行分割遷移的圖片即可。

歡迎大家關注我的微信公衆號,未來上面會推送python 機器學習 算法學習 深度學習 論文閱讀 以及偶爾的小雞湯等內容。ようこそいらっしゃい!

搜索 coderwangson 關注

image

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章