風格遷移–生成你想要的風格
標籤: pytorch
隨着深度網絡的流行,用AI作畫也不再是問題,比如下面這一張:
你能看出來是手畫的,還是自動生成的嗎。
下面介紹一個風格遷移網絡,能夠幫你生成任意你想要的style。本文也會提供一個Starry_Night_Over_the_Rhone
的style模型,大家可以自己後臺回覆style_transform
獲取代碼和模型。
下面簡單介紹一下風格遷移網絡。
網絡結構
上圖就是快速風格遷移網絡的結構,左邊虛線框裏面是一個Encoder-Decoder
結構,而右邊整個就是一個訓練好的vgg,主要用來做特徵提取進而能夠計算圖片間的損失。
從圖中可以看出輸入是一個x
,經過Image Transform Net
會變爲一個y^
,而這個y^
就是我們要的圖片,也就是經過風格轉換後的圖片。比如我們輸入一張東方明珠電視塔圖片作爲x
,那麼文章剛開始的那個圖片就是作爲y^
,那這個y^
是如何學習得到的呢。主要靠後面vgg網絡做損失,然後指導前面的Image Transform Net
學習。
下面介紹一下這個網絡中最重要的損失函數,這個損失函數不同於之前的分類網絡的損失,原來的分類網絡一般就是一個交叉熵函數,但是這裏的損失是一個預訓練好的vgg,從圖中可以看出,Loss Network
有三個輸入,分別是ys y^ yc
,其中ys
就是風格圖片,在本次實驗中我們選擇的是:
正是由於ys
的緣故,所以我們的y^
在風格上和它非常像。而yc
其實就是x
。將這三個輸入到vgg裏面,然後計算利用vgg強大的特徵提取能力,把提取的特徵做爲損失,我們的目的是使得我們y^
在內容上和yc
相近,而風格上和ys
更近,所以引出了兩類損失,第一類是風格損失,第二類是內容損失,對應圖中內容損失就直接對y^ yc
的中間特徵用mse計算即可,也就是右邊下面的那個損失,而風格損失是上面的三個,是對ys y^
的中間特徵計算gram得到。
最後我們優化這兩個損失就能保證我們的輸出y^
在風格和ys
更近,而內容上和yc
更近。
代碼簡析
這部分對代碼做一個簡單的分析,其中main.py
是主函數,裏面包含了兩個主要方法train、stylize
,其中train是用來訓練模型的, 如果你有充分的數據集,你可以自己加載數據來進行訓練,只需要修改Config裏面data_root
即可。
train裏面主要就是加載數據,加載模型TransformerNet
,而TransformerNet
就是前面說的那個Image Transform Net
,損失網絡同樣使用的是Vgg
,在訓練的過程中只更新TransformerNet
的參數,因爲Vgg
是作爲一個損失函數來用的,它直接使用一個ImageNet的預訓練參數即可。
而stylize函數則提供了一個測試,當我們訓練好了模型,就可以用這個函數來幫我們生成圖片了,我們在Config裏面指定一個content_path
,這裏我們可以假定是一個東方明珠,你可以用其他代替。在stylize裏面做的事情就是把TransformerNet
加載一下,注意要把訓練好的模型給加載上去,然後一次前向傳播即可。
如果你想直接用
如果你不太懂,而想嘗下鮮,那麼下載完代碼後,直接輸入python main.py
即可,如果想換個圖片,直接修改第40行的content_path = 'images.jpeg' # 需要進行分割遷移的圖片
即可。
歡迎大家關注我的微信公衆號,未來上面會推送python
機器學習
算法學習
深度學習
論文閱讀
以及偶爾的小雞湯
等內容。ようこそいらっしゃい!
搜索 coderwangson 關注