使用自動編解碼器網絡實現圖片噪音去除

在前面章節中,我們一再看到,訓練或使用神經網絡進行預測時,我們需要把數據轉換成張量。例如要把圖片輸入卷積網絡,我們需要把圖片轉換成二維張量,如果要把句子輸入LSTM網絡,我們需要把句子中的單詞轉換成one-hot-encoding向量。

這種數據類型轉換往往是由人設計的,我們本節介紹一種神經網絡,它能夠爲輸入數據自動找到合適的數據轉換方法,它自動把數據轉換成某種格式的張量,然後又能把相應張量還原回原有形態,這種網絡就叫自動編解碼器。

自動編解碼器的功能很像加解密系統,對加密而言,當把明文進行加密後,形成的密文是一種隨機字符串,再把密文解密後就可以得到明文,解密後的數據必須與加密前的完全一模一樣。自動編解碼器會把輸入的數據,例如是圖片轉換成給定維度的張量,例如一個含有16個元素的一維向量,解碼後它會把對應的含有16個元素的一維向量轉換爲原有圖片,不過轉換後的圖片與原圖片不一定完全一樣,但是圖片內容絕不會有重大改變。

自動編解碼器分爲兩部分,一部分叫encoder,它負責把數據轉換成固定格式,從數學上看,encoder相當於一個函數,被編碼的數據相當於輸入參數,編碼後的張量相當於函數輸出: ,其中f對應encoder,x對應要編碼的數據,例如圖片,z是編碼後的結果。

另一部分叫decoder,也就是把編碼器編碼的結果還原爲原有數據,用數學來表達就是: ,函數g相當於解碼器,它的輸入是編碼器輸出結果, 是解碼器還原結果,它與輸入編碼器的數據可能有差異,但主要內容會保持不變,如圖10-1:

圖10-1 編解碼器運行示意圖

如上圖,手寫數字圖片7經過編碼器後,轉換成給定維度的張量,例如含有16個元素的一維張量,然後經過解碼器處理後還原成一張手寫數字圖片7,還原的圖片與輸入的圖片圖像顯示上有些差異,但是他們都能表達手寫數字7這一含義。 代碼是對原理最好的解釋,我們看看實現過程:

from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras.datasets import mnist
from keras.utils import plot_model
from keras import backend as K

import numpy as np
import matplotlib.pyplot as plt

#加載手寫數字圖片數據
(x_train, _), (x_test, _) = mnist.load_data()
image_size = x_train.shape[1]
#把圖片大小統一轉換成28*28,並把像素點值都轉換爲[0,1]之間
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
接下來我們構建自動編解碼器網絡:
#構建解碼器
latent_inputs = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

'''
使用Conv2DTranspose做卷積操作的逆操作。相應的Conv2D做怎樣的計算操作,該網絡層就逆着來
'''
for filters in layer_filters[::-1]:
  x = Conv2DTranspose(filters = filters, kernel_size = kernel_size, 
                     activation='relu', strides = 2, padding='same')(x)

#還原輸入
outputs = Conv2DTranspose(filters = 1, kernel_size = kernel_size, 
                         activation='sigmoid', padding='same', 
                          name='decoder_output')(x)

decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()

我們把編碼器和解碼器前後相連,於是數據從編碼器輸入,編碼器將數據進行計算編號後所得的輸出直接傳給解碼器,解碼器進行相對於編碼器的逆運算最後得到類似於輸入編碼器的數據,相應代碼如下:

'''
將編碼器和解碼器前後相連,數據從編碼器輸入,編碼器運算後把結果直接傳遞給解碼器,
解碼器進行編碼器的逆運算,最後輸出與數據輸入時相似的結果
'''
autoencoder = Model(inputs, decoder(encoder(inputs)),
                   name='autoencoder')

autoencoder.compile(loss='mse', optimizer='adam')
'''

在訓練網絡時,輸入數據是x_train,對應標籤也是x_train,這意味着我們希望網絡將輸出儘可能的調整成與輸入一致

'''
autoencoder.fit(x_train, x_train, validation_data=(x_test, x_test), epochs = 1, 
                batch_size = batch_size)

網絡訓練好後,我們把圖片輸入網絡,編碼器把圖片轉換爲含有16個元素的一維向量,然後向量輸入解碼器,解碼器把向量還原爲一張二維圖片,相應代碼如下:

'''
把手寫數字圖片輸入編碼器然後再通過解碼器,檢驗輸出後的圖像與輸出時的圖像是否相似
'''
x_decoded = autoencoder.predict(x_test)
imgs = np.concatenate([x_test[:8], x_decoded[:8]])
imgs = imgs.reshape((4, 4, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
plt.figure()
plt.axis('off')
plt.title('Input image: first and second rows, Decoded: third and forth rows')
plt.imshow(imgs, interpolation='none', cmap = 'gray')
plt.savefig('input_and_decoded.png')
plt.show()

上面代碼運行後結果如圖10-2:

上面顯示圖片中,前兩行是輸入編解碼器的手寫數字圖片,後兩行是經過編碼然後還原後的圖片,如果仔細看我們可以發現兩者非常相像,但並不完全一樣,我們看第一行最後一個數字0和解碼後第三行最後一個數字0,兩者有比較明顯差異,但都會被解讀成數字0.

在代碼中需要注意的是,構建解碼器時我們使用了一個類叫Conv2DTranspose,它與Conv2D對應,是後者的反操作,如果把Conv2D看做對輸入數據的壓縮或加密,那麼Conv2DTranspose是對數據的解壓或解密。 另外還需要注意的是,因爲我們網絡層較少,因此訓練時只需要一次循環就好,如果網絡層多的話,我們需要增加循環次數才能使得網絡有良好的輸出效果。

2.使用編解碼器去除圖片噪音

在八零年代,改革開放不久後,一種‘稀有’的家電悄悄潛入很多家庭,那就是錄像機。你把一盤錄像帶推入機器,在電視上就可以把內容播放出來,有一些錄像帶它的磁帶遭到破壞的話,播放時畫面會飄散一系列‘雪花’,我們將那稱之爲畫面‘噪音’。當圖片含有‘噪音’時,圖片表現爲含有很多花點,如圖10-3所示:

圖10-3 含有噪音的圖片

在信號處理這一學科分支中,有很大一部分就在於研究如何去噪,幸運的是通過編解碼網絡也能夠實現圖片噪音去除的效果。本節我們先給手寫數字圖片增加噪音,使得圖片變得很難識別,然後我們再使用編解碼網絡去除圖片噪音,讓圖片回覆原狀。

圖片噪音本質上是在像素點上添加一些隨機值,這裏我們使用高斯分佈產生隨機值,其數學公式如下:

它有兩個決定性參數,分別是μ 和 σ,只要使得這兩個參數取不同的值,我們就可以得到相應分佈的隨機數,其中μ 稱之爲均值, σ稱之爲方差,我們看看如何使用代碼實現圖片加噪,然後構建編解碼網絡去噪音:

#使用高斯分佈產生圖片噪音
np.random.seed(1337)
#使用高斯分佈函數生成隨機數,均值0.5,方差0.5
noise = np.random.normal(loc=0.5, scale=0.5, size=x_train.shape)
x_train_noisy = x_train + noise

noise = np.random.normal(loc=0.5, scale=0.5, size=x_test.shape)
x_test_noisy = x_test + noise
#把像素點取值範圍轉換到[0,1]間
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)
上面的代碼先使用高斯函數產生隨機數,然後加到像素點上從而形成圖片噪音。接着我們看如何構建編解碼器實現圖片去噪:
#構造編解碼網絡,以下代碼與上一小節代碼大部分相同
input_shape = (image_size, image_size, 1)
batch_size = 32
kernel_size = 3
latent_dim = 16
layer_filters = [32, 64]
#構造編碼器
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for filters in layer_filters:
  x = Conv2D(filters = filters, kernel_size = kernel_size, strides = 2,
            activation='relu', padding='same')(x)

shape = K.int_shape(x)
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)
encoder = Model(inputs, latent, name='encoder')

#構造解碼器
latent_inputs = Input(shape=(latent_dim, ), name='decoder_input')
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

for filters in layer_filters[::-1]:
  x = Conv2DTranspose(filters = filters, kernel_size = kernel_size, strides = 2,
                     activation='relu', padding='same')(x)

outputs = Conv2DTranspose(filters=1, kernel_size=kernel_size, padding='same',
                         activation='sigmoid', name='decoder_output')(x)
decoder = Model(latent_inputs, outputs, name='decoder')

#將編碼器和解碼器前後相連
autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder')
autoencoder.compile(loss='mse', optimizer='adam')
#輸入數據是有噪音圖片,對應結果是無噪音圖片
autoencoder.fit(x_train_noisy, x_train, validation_data=(x_test_noisy, x_test),
               epochs = 10, batch_size = batch_size)

代碼中值得注意的是,訓練網絡時,訓練數據時含有噪音的圖片,對應結果是沒有噪音的圖片,也就是我們希望網絡能通過學習自動掌握去噪功能,訓練完成後,我們把測試圖片輸入網絡,看看噪音去除效果:

x_decoded = autoencoder.predict(x_test_noisy)

rows, cols = 3, 9
num = rows * cols
imgs = np.concatenate([x_test[:num], x_test_noisy[:num],
                      x_decoded[:num]])
imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
imgs = np.vstack(np.split(imgs, rows, axis=1))
imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
imgs = (imgs * 255).astype(np.uint8)
plt.figure()
plt.axis('off')
plt.title('Original images: top rows'
          'Corrupted images: middle rows'
         'Denoised Input: third rows')
plt.imshow(imgs, interpolation='none', cmap='gray')
plt.show()

上面代碼運行後如圖10-4所示:

圖10-4 網絡去噪效果 從上圖看,第一行是原圖,第二行是加了噪音的圖片,第三行是網絡去除噪音後的圖片。從上圖看,網絡去噪的效果還是比較完美的。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章