我們在使用Tensorflow的時候,有時候自帶的激活函數和損失函數不夠用,我們就要自己定義自己的函數。下面我給出一種方法,我試驗可行,當然我也是參考的官方文檔和一些博客。基於tf2.0
自定義損失函數
這是我自己做實驗的時候用到的一個損失函數,我需要把輸出的圖片和標籤圖片計算SSIM的值,然後用1-SSIM的值作爲損失函數值。
def tf_ssim_loss(y_true, y_pred):
total_loss = 1 - tf.image.ssim(y_true, y_pred, max_val=1)
return total_loss
這是第一步。
自定義激活函數
這是我實驗的時候定義的激活函數。具體每個人按自己需要來做。需要說明的是在tf裏面,我們使用的一些像指數函數exp()之類的東西,需要時K.的。因爲這樣纔算是對張量tensor進行計算。因爲默認網絡裏面傳遞的都是tensor.
當然需要在前面import.不過這個問題在python可以自動解決。
from tensorflow.keras import backend as K
def custom_activation_1(x):
cond = K.greater(x, 0)
return K.switch(cond, 1-(1-x**2)*K.exp(-x**2/2), x*K.exp(-x*x/2))
接下來
完成了上面的步驟之後,基本已經弄好了,但是運行的話,tf識別不了我們自定義的激活函數和損失函數。我麼需要在模型建立之前加上這句:
get_custom_objects().update({'custom_activation': Activation(custom_activation_1)})
這樣就可以使用我們的激活函數了。
但是,我們在保存完模型之後,加載模型又會出現問題。應該這樣加載模型。
# 保存模型
model.save('./saved/my_model.h5')
print("saved total mdoel.")
model = tf.keras.models.load_model('./saved/my_model.h5',
custom_objects={'tf_ssim_loss': tf_ssim_loss,
'custom_activation_1': Activation(custom_activation_1)})