神經翻譯筆記4擴展a第二部分. RNN在TF2.0中的實現方法略覽

神經翻譯筆記4擴展a第二部分. RNN在TF2.0中的實現方法略覽


與TF1.x的實現思路不同,在TF2.0中,RNN已經不再是個函數,而是一個封裝好的類,各種RNN和RNNCell與頂層抽象類Layer的關係也更加緊湊(需要說明的是說Layer頂層並非說它直接繼承自object,而是從……功能的角度,我覺得可以這麼說。真實實現裏的繼承關係是Layer --> Module --> AutoTrackable --> Trackable --> object)。但是另一方面,感覺新的版本里各個類的關係稍微有些雜亂,不知道後面會不會進一步重構。TF2.0的RNN相關各類關係大致如下圖所示

在這裏插入圖片描述

相關基類

tf.keras.layers.Layer

與TF1.14的實現基本相同,不再贅述

recurrent.DropoutRNNCellMixin

與之類似的類在TF1.x中以tf.nn.rnn_cell.DropoutWrapper形式出現,但當時考慮到還沒涉及到RNN的dropout就沒有引入,沒想到在這裏還是要說一說。TF2的實現比TF1的實現要簡單一些,這個類只是維護兩個dropout mask,一個是用於對輸入的mask,一個用於對傳遞狀態的mask(嚴格說是四個,在另一個維度上還考慮是對靜態圖的mask還是對eager模式的mask)。實現保證mask只被創建一次,因此每個batch使用的mask都相同

RNNCell相關

無論是官方給出的文本分類教程,還是我自己從TF1.x改的用更底層API實現的代碼,實際上都沒有用到Cell相關的對象。但是爲了完整起見(畢竟暴露的LSTM類背後還需要LSTMCell類對象作爲自己的成員變量),這裏還是稍作介紹

LSTMCell

本文以LSTM爲主,因此先從LSTMCell說起。與TF1.x不同,在2.x版本里,LSTMCell允許傳入一個implement參數,默認爲1,標記LSTM各門和輸出、狀態的計算方式。當取默認的1時,計算方式更像是論文中的方式,逐個計算各個門的結果;而如果設爲2,則使用TF1.x中組合成矩陣一併計算的方式。此外,由於LSTMCell還繼承了前述DropoutRNNCellMixin接口,因此可以在call裏對輸入和上一時間步傳來的狀態做dropout。注意由於LSTM有四個內部變量i\boldsymbol{i}f\boldsymbol{f}o\boldsymbol{o}c~\tilde{\boldsymbol{c}},因此需要各自生成四個不同的dropout mask

PeepholeLSTMCell

只是改寫了LSTMCell內部變量的計算邏輯,參見在TF1.x部分的介紹

StackedRNNCells

與TF1.x中的MultiRNNCell類似

AbstractRNNCell

純抽象類,類似TF1的RNNCell,如果用戶自己實現一個RNNCell需要 可以繼承於它。不過有趣的是內置的三種RNN實現所使用的Cell:SimpleRNNCellGRUCellLSTMCell均直接繼承自Layer

RNN相關

tf.keras.layers.RNN

所有後續RNN相關類的基類,承擔TF1.x中static_rnndynamic_rnn的雙重功能,主要邏輯分別集中在初始化函數__init__buildcall中(__call__也有一些邏輯,但是隻針對某些特殊情況)

RNN在初始化時傳入的參數個人感覺相對來講不如1.x直觀。其允許傳入的參數包括

  • cell:一種RNNCell的對象,也可以是列表或元組。當傳入的參數爲列表或元組時,會打包組合爲StackedRNNCells類對象
  • return_sequences:默認RNN只返回最後一個時間步的輸出。當此參數設爲True時,返回每個時間步的輸出
  • return_state:當此參數設爲True時,返回最終狀態
  • go_backwards:當此參數設爲True時,將輸入逆序處理
  • stateful:當此參數設爲True時,每個batch第i個樣本的最終狀態會作爲下個batch第i個樣本的初始狀態
  • unroll:當此參數設爲True時,相當於1.x版本中的static rnn,網絡被展開。文檔認爲展開網絡可以加速RNN,但顯然代價是使用的顯存資源會變多
  • time_major:當此參數設爲True時,第一個維度爲時間維;否則爲batch維
  • zero_output_for_mask:沒有在接口中直接暴露出來,而是隱藏在**kwargs中。當此參數設爲True時,mask對應的時間步輸出都爲0,否則照搬前一個時間步的輸出

build實際只是調用cellbuild方法,並做一些校驗

def build(self, input_shape):
    step_input_shape = get_step_input_shape(input_shape)
    if not self.cell.built:
        self.cell.build(step_input_shape)
    self._set_state_spec(state_size)
    if self.stateful:
        self.reset_states()
    self.built = True

call的核心是調用keras後端方法keras.backend.rnnK.rnn

def _process_inputs(inputs, initial_state, ...):
    if initial_state is not None:
        pass
    elif self.stateful:
        initial_state = self.states
    else:
        get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
        if get_initial_state_fn:
            initial_state = get_initial_state_fn()
        else:
            initial_state = zero_state
    return inputs, initial_state, ...

def call(self, inputs, ...):
    inputs, initial_state, ... = self._process_inputs(inputs, initial_state, ...)
    def step(inputs, states):
        output, new_states = self.cell.call(inputs, states)
    last_output, outputs, states = K.rnn(step, inputs, initial_state, ...)
    if self.stateful:
        updates = [assign_op(old, new) for old, new in zip(self.states, states)
        self.add_update(updates)
    if self.return_sequences:
        output = outputs
    else:
        output = last_output
    
    if self.return_state:
        return to_list(output) + states
    return output

K.rnn對RNN是否展開(unroll)和是否需要mask有不同的邏輯,這裏只列出不展開且有mask的邏輯。個人感覺和1.x版本中dynamic_rnn的實現方法大同小異

def rnn(step_function, inputs, initial_states, ...):
    # 轉換成time major
    inputs = swap_batch_timestep(inputs)
    mask = swap_batch_timestep(mask)
    time_steps_t = inputs[0].shape[0]
   
    input_ta = TensorArray(inputs)
    output_ta = TensorArray(shape=inputs[0].shape)
    mask_ta = TensorArray(mask)
    states = tuple(initial_states)
    prev_output = 0
    time = 0
    while time < time_steps_t:
        current_input = input_ta[time]
        mask_t = mask_ta[time]
        output, new_states = step_function(current_input, states)
        mask_output = 0 if zero_output_for_mask else prev_output
        new_output = where(mask, output, mask_output)
        new_states = where(mask, new_states, states)
        output_ta.append(new_output)
        prev_output, states = new_output, new_states
        time += 1
    return output_ta[-1], output_ta, states

recurrent.LSTM

與父類相比實際上只額外做了兩件事

  • 初始化時cell固定爲LSTMCell
  • 調用父類的call之前先重置兩個dropout mask

recurrent_v2.LSTM

使用tf.keras.layers.LSTM類對象時實際使用的類,之所以帶“v2”是因爲整合了CuDNN的實現,所以理論上速度會更快,效率會更高。不過使用時需加入如下兩行代碼

physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)

tf.keras.layers.Bidirectional

與RNN類似地,在TF2.0裏雙向RNN也不再實現爲函數,而是實現爲一個Layer對象的包裝器,爲Layer對象提供一定的額外功能。由於Bidirectional也是間接繼承自Layer類,因此其大部分邏輯也是蘊含在call方法中

初始化Bidirectional主要需要傳入一個Layer類對象layer——不過從實現來看,這個類對象應該還是要是RNN或者其子類的對象。可選的三個字段包括:

  • merge_mode,指定正向和反向RNN的輸出如何組合,可以是如下幾種選擇:求和sum、逐元素相乘mul、直接相連concat、求均值ave或直接返回兩個輸出組成的一個列表None
  • weights,指定兩個RNN的初始化權重
  • backward_layer:允許用戶直接傳入已經反向的RNN。如果backward_layerNone(默認情況),Bidirectional在初始化時會先根據layer對象的config重構一個RNN,再使用相同的配置構建對應的反向RNN。Bidrectional會強制讓自己的兩個RNN成員對被mask掉的部分輸出爲0(zero_output_for_mask強制爲True

Bidirectionalbuild實際上就是調用兩個RNN成員的build。對應地,call方法也是調用兩個RNN成員的call然後根據指定的merge_mode組合輸出。源代碼看上去略長是因爲處理了多個輸入和初始狀態不爲空的情況,而常見的單輸入無初始狀態下,邏輯相對直觀,大致如下:

def call(self, inputs):
    y = self.forward_layer(inputs, **kwargs)
    y_rev = self.backward_layer(inputs, **kwargs)

    if self.return_state:
        states = y[1:] + y_rev[1:]
        y, y_rev = y[0], y_rev[0]
    if self.return_sequences:
        y_rev = K.reverse(y_rev, 1)

    if self.merge_mode == 'concat':
        output = K.concatenate([y, y_rev])
    elif self.merge_mode == 'sum':
        output = y + y_rev
    elif self.merge_mode == 'ave':
        output = (y + y_rev) / 2
    elif self.merge_mode == 'mul':
        output = y * y_rev
    elif self.merge_mode is None:
        output = [y, y_rev]
    else:
        raise ValueError
    
    if self.return_state:
        if self.merge_mode is None:
            return output + states
        return [output] + states
    return output

後記

與前一篇文章相比,本文顯得有些粗糙。不過這也是意料之中的事情:LSTM的原理並不會因爲它是TF1還是TF2發生變化,因此實現也不會有太大的變化,變的只會是類的組織方式。料想下一篇討論PyTorch的文章,更多也會集中在結構設計上,畢竟具體實現已經在前一篇文章裏描述得差不多了

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章