用 tf.scan() 自主實現/改造 RNN cell (GRU/LSTM)

tensorflow RNN layer的搭建(GRU,LSTM等)中,我們展示瞭如何調用 tensorflow 內置模塊和函數,搭建RNN layer。然而,當一般的GRU/LSTM layer不適用時,我們希望對其 cell 進行改進,實現自主設計的改造版的RNN cell。
這方面研究工作代表的典型有:Time-LSTM,論文鏈接爲:What to Do Next: Modeling User Behaviors by Time-LSTM
下面,我們從tensorflow的內置函數 tf.scan()出發,展示如何自主實現/改造 RNN cell。

tf.scan()

tf.scan(
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)

fn : 一個二元函數
elems:一個tensor list
initializer:一個tensor,作爲初始化值
實際上,tf.scan()所能應用的類型不止如此,這裏只舉了我們所需要用到的部分
tf.scan 記錄中有一個很好的例子,我們借鑑一下:

x = [1,2,3]
z = 10

x = tf.convert_to_tensor(x)
z = tf.convert_to_tensor(z)

def f(x,y):
    return x+y

g = tf.scan(fn=f,elems = x,initializer=z)

sess = tf.Session()
sess.run(tf.global_variables_initializer)

sess.run(g)

得到:

In [97]: sess.run(g)
Out[97]: array([11, 13, 16], dtype=int32)

詳細的計算邏輯如下:
11 = 10(初始值initializer)+ 1(x[0])
13 = 11(上次的計算結果)+2(x[1])
16 = 13(上次的計算結果)+3(x[2])
可以發現,tf.scan() 從initializer 開始,把函數 fn 不斷應用在上次計算結果和elems當前的每一個元素上,不斷迭代,得到一系列輸出。

如果我們把elems看作RNN 的輸入seq,把fn 看作cell 的內部作用函數,那麼輸出seq 就是一系列隱狀態[h1,h2,,hNh_1, h_2, \cdots, h_N]。這和RNN的作用機制是相同的!

GRU cell實現

下面我們以較爲簡潔典型的GRU cell 爲例,來看tf.scan()的應用

def GRUunit(prev_h, x):

    dim_item = tf.shape(x)[1]
    dim_hid = DIM_HID

    w_xr = tf.get_variable('w_xr', [dim_item, dim_hid])
    w_hr = tf.get_variable('w_hr', [dim_hid, dim_hid])
    br = tf.get_variable('br', dim_hid)

    r = tf.sigmoid(tf.matmul(x, w_xr) + tf.matmul(prev_h, w_hr) + br)

    w_xz = tf.get_variable('w_xz', [dim_item, dim_hid])
    w_hz = tf.get_variable('w_hz', [dim_hid, dim_hid])
    bz = tf.get_variable('bz', dim_hid)

    z = tf.sigmoid(tf.matmul(x, w_xz) + tf.matmul(prev_h, w_hz) + bz)

    w_xh = tf.get_variable('w_xh', [dim_item, dim_hid])
    w_hh = tf.get_variable('w_hh', [dim_hid, dim_hid])
    bh = tf.get_variable('bh', dim_hid)

    h_ = tf.nn.tanh(tf.matmul(x, w_xh) + tf.matmul(tf.multiply(r, prev_h), w_hh) + bh)

    h = tf.multiply(z, h_) + tf.multiply(1-z, prev_h)

    return h
def GRUlayer(inputs, layer_name, dim_hid):

    with tf.variable_scope(layer_name):

        batch_size = tf.shape(inputs)[0]
        initial_hidden = tf.zeros([batch_size, dim_hid], tf.float32)

        states = tf.scan(GRUunit, tf.transpose(inputs,[1,0,2]), initializer = initial_hidden, name='states')

    return tf.transpose(states,[1,0,2]), states[-1,:]

至此,我們就完成了GRU cell 的自主實現。
注意到RNN input 的維度分別爲[batch_size, steps, item_dim],而tf.scan() 是對steps維度展開,因此在輸入和輸出時要對input的前兩維進行轉置。
對於其他RNN cell,只需要對GRUunit函數進行改寫即可。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章