Tensorflow 自定義激活函數和損失函數

\qquad我們在使用Tensorflow的時候,有時候自帶的激活函數和損失函數不夠用,我們就要自己定義自己的函數。下面我給出一種方法,我試驗可行,當然我也是參考的官方文檔和一些博客。基於tf2.0

自定義損失函數

\qquad這是我自己做實驗的時候用到的一個損失函數,我需要把輸出的圖片和標籤圖片計算SSIM的值,然後用1-SSIM的值作爲損失函數值。

def tf_ssim_loss(y_true, y_pred):
    total_loss = 1 - tf.image.ssim(y_true, y_pred, max_val=1)
    return total_loss

\qquad這是第一步。

自定義激活函數

\qquad這是我實驗的時候定義的激活函數。具體每個人按自己需要來做。需要說明的是在tf裏面,我們使用的一些像指數函數exp()之類的東西,需要時K.的。因爲這樣纔算是對張量tensor進行計算。因爲默認網絡裏面傳遞的都是tensor.
\qquad當然需要在前面import.不過這個問題在python可以自動解決。

from tensorflow.keras import backend as K
def custom_activation_1(x):
    cond = K.greater(x, 0)
    return K.switch(cond, 1-(1-x**2)*K.exp(-x**2/2), x*K.exp(-x*x/2))

接下來

\qquad完成了上面的步驟之後,基本已經弄好了,但是運行的話,tf識別不了我們自定義的激活函數和損失函數。我麼需要在模型建立之前加上這句:

get_custom_objects().update({'custom_activation': Activation(custom_activation_1)})

\qquad這樣就可以使用我們的激活函數了。
\qquad但是,我們在保存完模型之後,加載模型又會出現問題。應該這樣加載模型。

# 保存模型
model.save('./saved/my_model.h5')
print("saved total mdoel.")
model = tf.keras.models.load_model('./saved/my_model.h5',
                                   custom_objects={'tf_ssim_loss': tf_ssim_loss,
                                                   'custom_activation_1': Activation(custom_activation_1)})
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章