Tensorflow 自定义激活函数和损失函数

\qquad我们在使用Tensorflow的时候,有时候自带的激活函数和损失函数不够用,我们就要自己定义自己的函数。下面我给出一种方法,我试验可行,当然我也是参考的官方文档和一些博客。基于tf2.0

自定义损失函数

\qquad这是我自己做实验的时候用到的一个损失函数,我需要把输出的图片和标签图片计算SSIM的值,然后用1-SSIM的值作为损失函数值。

def tf_ssim_loss(y_true, y_pred):
    total_loss = 1 - tf.image.ssim(y_true, y_pred, max_val=1)
    return total_loss

\qquad这是第一步。

自定义激活函数

\qquad这是我实验的时候定义的激活函数。具体每个人按自己需要来做。需要说明的是在tf里面,我们使用的一些像指数函数exp()之类的东西,需要时K.的。因为这样才算是对张量tensor进行计算。因为默认网络里面传递的都是tensor.
\qquad当然需要在前面import.不过这个问题在python可以自动解决。

from tensorflow.keras import backend as K
def custom_activation_1(x):
    cond = K.greater(x, 0)
    return K.switch(cond, 1-(1-x**2)*K.exp(-x**2/2), x*K.exp(-x*x/2))

接下来

\qquad完成了上面的步骤之后,基本已经弄好了,但是运行的话,tf识别不了我们自定义的激活函数和损失函数。我么需要在模型建立之前加上这句:

get_custom_objects().update({'custom_activation': Activation(custom_activation_1)})

\qquad这样就可以使用我们的激活函数了。
\qquad但是,我们在保存完模型之后,加载模型又会出现问题。应该这样加载模型。

# 保存模型
model.save('./saved/my_model.h5')
print("saved total mdoel.")
model = tf.keras.models.load_model('./saved/my_model.h5',
                                   custom_objects={'tf_ssim_loss': tf_ssim_loss,
                                                   'custom_activation_1': Activation(custom_activation_1)})
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章