【Tensorflow】自主實現包含全節點Cell的LSTM層(Cell-Holding LSTM Layer)

0x00 前言

常用的LSTM,或是雙向LSTM,輸出的結果通常是以下兩個:
1) outputs,包括所有節點的hidden
2) 末節點的state,包括末節點的hidden和cell
大部分任務有這些就足夠了,state是隨着節點間信息的傳遞依次變化並容納更多信息,
所以通常末狀態的cell就囊括了所有信息,不需要中間每個節點的cell信息,
但如果我們的研究過程中需要用到這些cell該如何是好呢?

近期的任務中,需要每個節點的前後節點cell信息來做某種判斷,
所以屬於一個較爲特殊的任務,自主實現了一下這個同樣也會反饋cell的LSTM,
哦順帶一提Cell-Holding,是強行爲了簡稱成CHD取的名字(笑)

0x01 分析與設計

首先分析源碼,看一下通常LSTM層調用使用 dynamic_rnn 的實現邏輯,
原邏輯大概是這樣的:

outputs = []
state = Cell.zero_state(N, tf.float32)  # state = (hidden, cell)
for input in inputs:
    output, state = Cell(input, state)  # hidden, (hidden, cell) = Cell()
    outputs.append(output)  # outputs.append(hidden)
return outputs, state  # outputs := a list of (hidden)

那麼其實……我們只需要重新實現一個簡化的版本,讓cell留下來即可。
此處使用的邏輯大概是這樣的:

states_case = []
state = Cell.zero_state(N, tf.float32)  # state = (hidden, cell)
for input in inputs:
    output, state = Cell(input, state)  # hidden, (hidden, cell) = Cell()
    outputs.append(output)  # states_case.append((hidden, cell))
return states_case  # states_case := list of (hidden, cell)

爲了實現這些,就需要做到以下幾件事情:
1) 獲取或共享已有LSTM層的BasicLSTMCell
2) 編寫Cell相關計算,保留LSTM計算途中的信息,可自定義獲取輸出的格式
3) 採用設計的輸出格式使用這些節點信息,以完成其他任務

0x02 Source Code

Advanced LSTM Layer

[LstmLayer] in tf_layers
首先要在不影響功能的情況下改寫原有的LSTM Layer,令其支持獲取BasicCell的操作

class LstmLayer(object):
    # based on LSTM Layer, thanks for @lhw446
    def __init__(self, input_dim, num_units, sequence_length=None, bidirection=False, name="lstm"):
        self.input_dim = input_dim
        self.num_units = num_units
        self.bidirection = bidirection
        self.sequence_length = sequence_length
        self.name = name

        # `with ... as...` remains assignment work.
        self.lstm_fw_cell = None
        self.lstm_bw_cell = None

        with tf.name_scope('%s_def' % (self.name)):
            self.lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, state_is_tuple=True)
            if self.bidirection:
                self.lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, state_is_tuple=True)


    def __call__(self, inputs, sequence_length=None, time_major=False,
                 initial_state_fw=None, initial_state_bw=None):
        inputs_shape = tf.shape(inputs)
        inputs = tf.reshape(inputs, [-1, inputs_shape[-2], self.input_dim])
        sequence_length = self.sequence_length if sequence_length is None \
            else tf.reshape(sequence_length, [-1])

        if initial_state_fw is not None:
            initial_state_fw = tf.nn.rnn_cell.LSTMStateTuple(
                tf.reshape(initial_state_fw[0], [-1, self.num_units]),
                tf.reshape(initial_state_fw[1], [-1, self.num_units]))
        if initial_state_bw is not None:
            initial_state_bw = tf.nn.rnn_cell.LSTMStateTuple(
                tf.reshape(initial_state_bw[0], [-1, self.num_units]),
                tf.reshape(initial_state_bw[1], [-1, self.num_units]))

        resh_1 = lambda tensors: tf.reshape(
            tensors, tf.concat([inputs_shape[:-1], [tf.shape(tensors)[-1]]], 0))
        resh_2 = lambda tensors: tf.reshape(
            tensors, tf.concat([inputs_shape[:-2], [tf.shape(tensors)[-1]]], 0))

        with tf.variable_scope('%s_cal' % (self.name)):
            if self.bidirection:
                outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
                    self.lstm_fw_cell, self.lstm_bw_cell, inputs,
                    sequence_length=sequence_length,
                    initial_state_fw=initial_state_fw,
                    initial_state_bw=initial_state_bw,
                    time_major=time_major, dtype=tf.float32)
                # (fw_outputs, bw_outputs)
                outputs = tf.nn.rnn_cell.LSTMStateTuple(resh_1(outputs[0]), resh_1(outputs[1]))
                # ((fw_c_states, fw_m_states), (bw_c_states, bw_m_states))
                output_states = tf.nn.rnn_cell.LSTMStateTuple(
                    tf.nn.rnn_cell.LSTMStateTuple(resh_2(output_states[0][0]), resh_2(output_states[0][1])),
                    tf.nn.rnn_cell.LSTMStateTuple(resh_2(output_states[1][0]), resh_2(output_states[1][1])))
            else:
                outputs, output_states = tf.nn.dynamic_rnn(
                    self.lstm_fw_cell, inputs, sequence_length=sequence_length,
                    initial_state=initial_state_fw,
                    time_major=time_major, dtype=tf.float32)
                outputs = resh_1(outputs)  # (outputs)
                # (c_states, m_states)
                output_states = tf.nn.rnn_cell.LSTMStateTuple(
                    resh_2(output_states[0]), resh_2(output_states[1]))

            return outputs, output_states

Cell-HolDing Layer

chd_lstm_layer in network
然後基於目標LSTM層,構建使用相同基本單元的scope,設定初始零狀態,逐層計算
(此處僅剪枝了所有的padding位,沒有特意做加速,用了簡單的python-like的for循環)
(且爲了本次實驗需要,沒有將hidden和cell區分開來,而是直接保存了state整體,可自行修改)

def chd_lstm_layer(self, inputs, target_layer):
    cell = target_layer.lstm_fw_cell

    with tf.variable_scope('%s_cal' % (target_layer.name)):
        # generate initial states for current inputs
        states_case = []
        for batch_idx in range(self.batch_size):
            batch_state_case = []
            state = cell.zero_state(1, tf.float32)
            for time_step in range(self.seg_len[batch_idx]):
                tf_input = inputs[batch_idx, time_step]
                output, _state = cell(
                    tf.reshape(tf_input, [1, -1]), state)
                batch_state_case.append(_state)
                state = _state
            states_case.append(batch_state_case)
        # a nested list of states [batch_size, seg_len]
        return states_case, cell

上述是任務需要,
主要演示了可以簡單的循環調用給定LSTM層的Cell進行計算,
在對齊的情況下還可以通過stack等操作拼成一個tf的矩陣使用。
其中用作循環迭代次數的參數 self.batch_size self.seg_len等,
不可以是tf.placeholder,因爲range內必須爲一個固定的數值而不能爲一個佔位符(tf.loop不知道能不能做到)
所以在feed_dict前,我做了如下的操作,將這些固定數值作爲 instance_variables 傳給網絡以供使用。

def gen_infer_inputs(self, data):
    # data = merge_by_batch_size(batch_data_generate(data))
    self.batch_size = data['cell_lens'].shape[0]
    self.seg_len = data['cell_lens']
    self.can_len = data['candi_mask'].sum(-1)
    return {
        self.input_data: data['input_data'],
        self.cell_lens: data['cell_lens'],
        self.candidates: data['candidates'],
        self.candi_mask: data['candi_mask'],
        self.keep_prob: 1.0,
    }

Further usage on states_case

others_layer in network
獲取了states_case之後,可以用於各個位置的使用
下文中給出一個使用案例,此處用於計算相同LSTM序列中,替換其中任意節點爲其他節點的輸出。

def replace_layer(self, forward_emb, candidate_emb):
    backward_emb = self.get_reverse(forward_emb, rev_length=self.cell_lens + 2)

    fw_states, fw_cell = self.chd_lstm_layer(
        forward_emb, self.forward_lstm)
    bw_states, bw_cell = self.chd_lstm_layer(
        backward_emb, self.backward_lstm)

    hidden_case = []
    for batch_idx in range(self.batch_size):
        batch_case = []
        for time_step in range(self.seg_len[batch_idx]):
            time_case = []
            for candidate_idx in range(self.can_len[batch_idx, time_step]):
                tf_input = candidate_emb[batch_idx, time_step, candidate_idx]
                fw_hidden, _ = fw_cell(
                    tf.reshape(tf_input, [1, -1]),
                    fw_states[batch_idx][time_step])
                bw_hidden, _ = bw_cell(
                    tf.reshape(tf_input, [1, -1]),
                    bw_states[batch_idx][-time_step])
                hidden = tf.concat([fw_hidden, bw_hidden], -1)
                time_case.append(hidden)
            batch_case.append(time_case)
        hidden_case.append(batch_case)
    return hidden_case  # a nested list.

0x03 後記

cell因其持續更新且後者包含前者信息的特性通常不被保存,
但是 LSTMCell RNNCell 的調用卻需要完整的state(包括hiddencell),
在我們對已經計算完畢的LSTM序列中內部的某些節點有所想法時,就很難回溯了,
所以說不定這種layer也是有一定價值的,目前tensorflow裏還沒有整合成類似的層,
所以自行手寫了一個,雖說不是太複雜,不過提供了這樣一種想法,記錄一下~
(說不定以後就加了這個層呢~ 到時候我可以指着這篇文章說我早就想到咯^_^)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章