Notes
要實現 [1] 的 piece-wise threshold function。與 Htanh 類似,它也需要自定義梯度,因此用到 `@tf.custom_gradient`。
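函數是:
$$
f(x) = \begin{cases}
0, & x < 0.5 - \epsilon \\
x, & 0.5 - \epsilon \le x < 0.5 + \epsilon \\
1, & x \ge 0.5 + \epsilon
\end{cases}
$$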
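定義其導數:
$$
f'(x) = \begin{cases}
1, & 0.5 - \epsilon \le x < 0.5 + \epsilon \\
0, & \text{otherwise}
\end{cases}
$$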
其中 $\epsilon$ 是超參數,訓練時會變化,所以用 placeholder 傳入。
Code
import tensorflow as tf
import numpy as np
@tf.custom_gradient
def pw_threshold(x, epsilon):
    """Piece-wise threshold function with a hand-written gradient.

    Forward pass::

        y = 0   if x <  0.5 - epsilon
        y = x   if 0.5 - epsilon <= x < 0.5 + epsilon
        y = 1   if x >= 0.5 + epsilon

    Backward pass: the upstream gradient ``dy`` is passed through where
    ``x`` lies in the linear band ``[0.5 - epsilon, 0.5 + epsilon)`` and is
    zeroed elsewhere.

    Args:
        x: Input tensor (the demo in this note feeds float64 values).
        epsilon: Half-width of the linear band. May be a scalar tensor
            (e.g. a placeholder) or a Python number; it is cast to
            ``x.dtype`` before use.

    Returns:
        A ``(y, grad)`` pair as required by ``tf.custom_gradient``; callers
        of the decorated function only receive ``y``.
    """
    # Cast so a Python float / float32 epsilon works with a float64 x
    # without a dtype mismatch in the comparisons below.
    eps = tf.cast(epsilon, x.dtype)
    in_band = ((0.5 - eps) <= x) & (x < (0.5 + eps))
    above_band = x >= (0.5 + eps)
    zeros = tf.zeros_like(x)
    # The two conditions are mutually exclusive, so a nested where gives the
    # same result as summing two wheres (NaN inputs: both False -> 0).
    y = tf.where(in_band, x, tf.where(above_band, tf.ones_like(x), zeros))

    def grad(dy):
        # tf.custom_gradient needs one gradient per input, in order.
        # The band mask is reused from the forward pass.
        # epsilon is a hyper-parameter: report a zero gradient for it rather
        # than epsilon itself, which would be a bogus, dy-independent value.
        return tf.where(in_band, dy, tf.zeros_like(dy)), tf.zeros_like(epsilon)

    return y, grad
# Test: build a small graph, then print x, y = pw_threshold(x, 0.25) and
# dy/dx, each from its own session run (same order as before).
epsilon = tf.placeholder("float64", [])
x = tf.constant(np.arange(-0.25, 1.26, 0.25))
y = pw_threshold(x, epsilon)
grad = tf.gradients(y, x)

eps_feed = {epsilon: 0.25}
with tf.Session() as sess:
    # x does not depend on the placeholder, so it is run without a feed.
    for label, fetch, feeds in (
        ("x:", x, None),
        ("y:", y, eps_feed),
        ("grad:", grad, eps_feed),
    ):
        print(label, sess.run(fetch, feed_dict=feeds))
輸出:
x: [-0.25 0. 0.25 0.5 0.75 1. 1.25]
y: [0. 0. 0.25 0.5 1. 1. 1. ]
grad: [array([0., 0., 1., 1., 0., 0., 0.])]