在tensorflow RNN layer的搭建(GRU,LSTM等)中,我們展示瞭如何調用 tensorflow 內置模塊和函數,搭建RNN layer。然而,當一般的GRU/LSTM layer不適用時,我們希望對其 cell 進行改進,實現自主設計的改造版的RNN cell。
這方面研究工作代表的典型有:Time-LSTM,論文鏈接爲:What to Do Next: Modeling User Behaviors by Time-LSTM
下面,我們從tensorflow的內置函數 tf.scan()出發,展示如何自主實現/改造 RNN cell。
tf.scan()
tf.scan(
fn,
elems,
initializer=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
fn : 一個二元函數
elems:一個tensor list
initializer:一個tensor,作爲初始化值
實際上,tf.scan()所能應用的類型不止如此,這裏只舉了我們所需要用到的部分
在tf.scan 記錄中有一個很好的例子,我們借鑑一下:
x = [1,2,3]
z = 10
x = tf.convert_to_tensor(x)
z = tf.convert_to_tensor(z)
def f(x,y):
return x+y
g = tf.scan(fn=f,elems = x,initializer=z)
sess = tf.Session()
sess.run(tf.global_variables_initializer)
sess.run(g)
得到:
In [97]: sess.run(g)
Out[97]: array([11, 13, 16], dtype=int32)
詳細的計算邏輯如下:
11 = 10(初始值initializer)+ 1(x[0])
13 = 11(上次的計算結果)+2(x[1])
16 = 13(上次的計算結果)+3(x[2])
可以發現,tf.scan() 從initializer 開始,把函數 fn 不斷應用在上次計算結果和elems當前的每一個元素上,不斷迭代,得到一系列輸出。
如果我們把elems看作RNN 的輸入seq,把fn 看作cell 的內部作用函數,那麼輸出seq 就是一系列隱狀態[]。這和RNN的作用機制是相同的!
GRU cell實現
下面我們以較爲簡潔典型的GRU cell 爲例,來看tf.scan()的應用
def GRUunit(prev_h, x):
dim_item = tf.shape(x)[1]
dim_hid = DIM_HID
w_xr = tf.get_variable('w_xr', [dim_item, dim_hid])
w_hr = tf.get_variable('w_hr', [dim_hid, dim_hid])
br = tf.get_variable('br', dim_hid)
r = tf.sigmoid(tf.matmul(x, w_xr) + tf.matmul(prev_h, w_hr) + br)
w_xz = tf.get_variable('w_xz', [dim_item, dim_hid])
w_hz = tf.get_variable('w_hz', [dim_hid, dim_hid])
bz = tf.get_variable('bz', dim_hid)
z = tf.sigmoid(tf.matmul(x, w_xz) + tf.matmul(prev_h, w_hz) + bz)
w_xh = tf.get_variable('w_xh', [dim_item, dim_hid])
w_hh = tf.get_variable('w_hh', [dim_hid, dim_hid])
bh = tf.get_variable('bh', dim_hid)
h_ = tf.nn.tanh(tf.matmul(x, w_xh) + tf.matmul(tf.multiply(r, prev_h), w_hh) + bh)
h = tf.multiply(z, h_) + tf.multiply(1-z, prev_h)
return h
def GRUlayer(inputs, layer_name, dim_hid):
with tf.variable_scope(layer_name):
batch_size = tf.shape(inputs)[0]
initial_hidden = tf.zeros([batch_size, dim_hid], tf.float32)
states = tf.scan(GRUunit, tf.transpose(inputs,[1,0,2]), initializer = initial_hidden, name='states')
return tf.transpose(states,[1,0,2]), states[-1,:]
至此,我們就完成了GRU cell 的自主實現。
注意到RNN input 的維度分別爲[batch_size, steps, item_dim],而tf.scan() 是對steps維度展開,因此在輸入和輸出時要對input的前兩維進行轉置。
對於其他RNN cell,只需要對GRUunit函數進行改寫即可。