Notes
要实现 [1] 的 piece-wise threshold function,类似于 Htanh,也需要自定义梯度,用到 `@tf.custom_gradient`。
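函数是:
$$
y = f(x) = \begin{cases} 0, & x < 0.5 - \epsilon \\ x, & 0.5 - \epsilon \le x < 0.5 + \epsilon \\ 1, & x \ge 0.5 + \epsilon \end{cases}
$$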
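定义其导数:
$$
\frac{\partial y}{\partial x} = \begin{cases} 1, & 0.5 - \epsilon \le x < 0.5 + \epsilon \\ 0, & \text{otherwise} \end{cases}
$$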
其中 $\epsilon$ 是超参,训练时会变,用 placeholder 传参。
Codes
import tensorflow as tf
import numpy as np
@tf.custom_gradient
def pw_threshold(x, epsilon):
    """Piece-wise threshold with a hand-written gradient.

    Forward pass, element-wise on ``x``:
        * ``0`` where ``x < 0.5 - epsilon``
        * ``x`` where ``0.5 - epsilon <= x < 0.5 + epsilon``
        * ``1`` where ``x >= 0.5 + epsilon``

    Backward pass: the upstream gradient is passed through unchanged inside
    the linear band ``[0.5 - epsilon, 0.5 + epsilon)`` and zeroed elsewhere.

    Args:
        x: Input tensor.
        epsilon: Half-width of the linear band around 0.5. It is treated as
            a hyper-parameter, so no gradient flows into it.

    Returns:
        ``(y, grad)`` as required by ``tf.custom_gradient``; callers of the
        decorated function receive only ``y``.
    """
    lower = 0.5 - epsilon
    upper = 0.5 + epsilon
    in_band = (lower <= x) & (x < upper)
    above_band = x >= upper

    ones = tf.ones_like(x)
    zeros = tf.zeros_like(x)
    # The two masks are mutually exclusive, so the sum selects x inside the
    # band, 1 above it and 0 below it.
    y = tf.where(in_band, x, zeros) + tf.where(above_band, ones, zeros)

    def grad(dy):
        # Reuse the forward mask: the derivative w.r.t. x is 1 inside the
        # band and 0 everywhere else.
        dx = tf.where(in_band, dy, tf.zeros_like(dy))
        # One gradient must be returned per input. The output is
        # piecewise-constant in epsilon, so its gradient is zero -- returning
        # epsilon itself would inject its value as a bogus gradient.
        d_epsilon = tf.zeros_like(epsilon)
        return dx, d_epsilon

    return y, grad
# Test: evaluate the function and its gradient on a small grid of inputs.
epsilon = tf.placeholder(tf.float64, shape=[])
x = tf.constant(np.arange(-0.25, 1.26, 0.25))
y = pw_threshold(x, epsilon)
grad = tf.gradients(y, x)

# epsilon is supplied at run time through the placeholder.
feed = {epsilon: 0.25}
with tf.Session() as sess:
    print("x:", sess.run(x))
    print("y:", sess.run(y, feed_dict=feed))
    print("grad:", sess.run(grad, feed_dict=feed))
输出:
x: [-0.25 0. 0.25 0.5 0.75 1. 1.25]
y: [0. 0. 0.25 0.5 1. 1. 1. ]
grad: [array([0., 0., 1., 1., 0., 0., 0.])]