神經翻譯筆記4擴展a第二部分. RNN在TF2.0中的實現方法略覽
文章目錄
與TF1.x的實現思路不同,在TF2.0中,RNN已經不再是個函數,而是一個封裝好的類,各種RNN和RNNCell與頂層抽象類
Layer
的關係也更加緊湊(需要說明的是說Layer
頂層並非說它直接繼承自object
,而是從……功能的角度,我覺得可以這麼說。真實實現裏的繼承關係是Layer --> Module --> AutoTrackable --> Trackable --> object
)。但是另一方面,感覺新的版本里各個類的關係稍微有些雜亂,不知道後面會不會進一步重構。TF2.0的RNN相關各類關係大致如下圖所示
相關基類
tf.keras.layers.Layer
與TF1.14的實現基本相同,不再贅述
recurrent.DropoutRNNCellMixin
與之類似的類在TF1.x中以tf.nn.rnn_cell.DropoutWrapper
形式出現,但當時考慮到還沒涉及到RNN的dropout就沒有引入,沒想到在這裏還是要說一說。TF2的實現比TF1的實現要簡單一些,這個類只是維護兩個dropout mask,一個是用於對輸入的mask,一個用於對傳遞狀態的mask(嚴格說是四個,在另一個維度上還考慮是對靜態圖的mask還是對eager模式的mask)。實現保證mask只被創建一次,因此每個batch使用的mask都相同
RNNCell相關
無論是官方給出的文本分類教程,還是我自己從TF1.x改的用更底層API實現的代碼,實際上都沒有用到Cell相關的對象。但是爲了完整起見(畢竟暴露的LSTM
類背後還需要LSTMCell
類對象作爲自己的成員變量),這裏還是稍作介紹
LSTMCell
本文以LSTM爲主,因此先從LSTMCell
說起。與TF1.x不同,在2.x版本里,LSTMCell
允許傳入一個implement
參數,默認爲1,標記LSTM各門和輸出、狀態的計算方式。當取默認的1時,計算方式更像是論文中的方式,逐個計算各個門的結果;而如果設爲2,則使用TF1.x中組合成矩陣一併計算的方式。此外,由於LSTMCell
還繼承了前述DropoutRNNCellMixin
接口,因此可以在call
裏對輸入和上一時間步傳來的狀態做dropout。注意由於LSTM有四個內部變量、、和,因此需要各自生成四個不同的dropout mask
PeepholeLSTMCell
只是改寫了LSTMCell
內部變量的計算邏輯,參見在TF1.x部分的介紹
StackedRNNCells
與TF1.x中的MultiRNNCell
類似
AbstractRNNCell
純抽象類,類似TF1的RNNCell
,如果用戶自己實現一個RNNCell
,需要 可以繼承於它。不過有趣的是內置的三種RNN
實現所使用的Cell:SimpleRNNCell
、GRUCell
、LSTMCell
均直接繼承自Layer
RNN相關
tf.keras.layers.RNN
所有後續RNN相關類的基類,承擔TF1.x中static_rnn
和dynamic_rnn
的雙重功能,主要邏輯分別集中在初始化函數__init__
、build
和call
中(__call__
也有一些邏輯,但是隻針對某些特殊情況)
RNN
在初始化時傳入的參數個人感覺相對來講不如1.x直觀。其允許傳入的參數包括
cell
:一種RNNCell的對象,也可以是列表或元組。當傳入的參數爲列表或元組時,會打包組合爲StackedRNNCells
類對象return_sequences
:默認RNN只返回最後一個時間步的輸出。當此參數設爲True
時,返回每個時間步的輸出return_state
:當此參數設爲True
時,返回最終狀態go_backwards
:當此參數設爲True
時,將輸入逆序處理stateful
:當此參數設爲True
時,每個batch第i個樣本的最終狀態會作爲下個batch第i個樣本的初始狀態unroll
:當此參數設爲True
時,相當於1.x版本中的static rnn,網絡被展開。文檔認爲展開網絡可以加速RNN,但顯然代價是使用的顯存資源會變多time_major
:當此參數設爲True
時,第一個維度爲時間維;否則爲batch維zero_output_for_mask
:沒有在接口中直接暴露出來,而是隱藏在**kwargs
中。當此參數設爲True
時,mask對應的時間步輸出都爲0,否則照搬前一個時間步的輸出
build
實際只是調用cell
的build
方法,並做一些校驗
def build(self, input_shape):
step_input_shape = get_step_input_shape(input_shape)
if not self.cell.built:
self.cell.build(step_input_shape)
self._set_state_spec(state_size)
if self.stateful:
self.reset_states()
self.built = True
call
的核心是調用keras後端方法keras.backend.rnn
(K.rnn
)
def _process_inputs(inputs, initial_state, ...):
if initial_state is not None:
pass
elif self.stateful:
initial_state = self.states
else:
get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
if get_initial_state_fn:
initial_state = get_initial_state_fn()
else:
initial_state = zero_state
return inputs, initial_state, ...
def call(self, inputs, ...):
inputs, initial_state, ... = self._process_inputs(inputs, initial_state, ...)
def step(inputs, states):
output, new_states = self.cell.call(inputs, states)
last_output, outputs, states = K.rnn(step, inputs, initial_state, ...)
if self.stateful:
updates = [assign_op(old, new) for old, new in zip(self.states, states)
self.add_update(updates)
if self.return_sequences:
output = outputs
else:
output = last_output
if self.return_state:
return to_list(output) + states
return output
K.rnn
對RNN是否展開(unroll)和是否需要mask有不同的邏輯,這裏只列出不展開且有mask的邏輯。個人感覺和1.x版本中dynamic_rnn
的實現方法大同小異
def rnn(step_function, inputs, initial_states, ...):
# 轉換成time major
inputs = swap_batch_timestep(inputs)
mask = swap_batch_timestep(mask)
time_steps_t = inputs[0].shape[0]
input_ta = TensorArray(inputs)
output_ta = TensorArray(shape=inputs[0].shape)
mask_ta = TensorArray(mask)
states = tuple(initial_states)
prev_output = 0
time = 0
while time < time_steps_t:
current_input = input_ta[time]
mask_t = mask_ta[time]
output, new_states = step_function(current_input, states)
mask_output = 0 if zero_output_for_mask else prev_output
new_output = where(mask, output, mask_output)
new_states = where(mask, new_states, states)
output_ta.append(new_output)
prev_output, states = new_output, new_states
time += 1
return output_ta[-1], output_ta, states
recurrent.LSTM
與父類相比實際上只額外做了兩件事
- 初始化時cell固定爲
LSTMCell
- 調用父類的
call
之前先重置兩個dropout mask
recurrent_v2.LSTM
使用tf.keras.layers.LSTM
類對象時實際使用的類,之所以帶“v2”是因爲整合了CuDNN的實現,所以理論上速度會更快,效率會更高。不過使用時需加入如下兩行代碼
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
tf.keras.layers.Bidirectional
與RNN類似地,在TF2.0裏雙向RNN也不再實現爲函數,而是實現爲一個Layer
對象的包裝器,爲Layer
對象提供一定的額外功能。由於Bidirectional
也是間接繼承自Layer
類,因此其大部分邏輯也是蘊含在call
方法中
初始化Bidirectional
主要需要傳入一個Layer
類對象layer
——不過從實現來看,這個類對象應該還是要是RNN
或者其子類的對象。可選的三個字段包括:
merge_mode
,指定正向和反向RNN的輸出如何組合,可以是如下幾種選擇:求和sum
、逐元素相乘mul
、直接相連concat
、求均值ave
或直接返回兩個輸出組成的一個列表None
weights
,指定兩個RNN的初始化權重backward_layer
:允許用戶直接傳入已經反向的RNN。如果backward_layer
爲None
(默認情況),Bidirectional
在初始化時會先根據layer
對象的config重構一個RNN,再使用相同的配置構建對應的反向RNN。Bidrectional
會強制讓自己的兩個RNN成員對被mask掉的部分輸出爲0(zero_output_for_mask
強制爲True
)
Bidirectional
的build
實際上就是調用兩個RNN成員的build
。對應地,call
方法也是調用兩個RNN成員的call
然後根據指定的merge_mode
組合輸出。源代碼看上去略長是因爲處理了多個輸入和初始狀態不爲空的情況,而常見的單輸入無初始狀態下,邏輯相對直觀,大致如下:
def call(self, inputs):
y = self.forward_layer(inputs, **kwargs)
y_rev = self.backward_layer(inputs, **kwargs)
if self.return_state:
states = y[1:] + y_rev[1:]
y, y_rev = y[0], y_rev[0]
if self.return_sequences:
y_rev = K.reverse(y_rev, 1)
if self.merge_mode == 'concat':
output = K.concatenate([y, y_rev])
elif self.merge_mode == 'sum':
output = y + y_rev
elif self.merge_mode == 'ave':
output = (y + y_rev) / 2
elif self.merge_mode == 'mul':
output = y * y_rev
elif self.merge_mode is None:
output = [y, y_rev]
else:
raise ValueError
if self.return_state:
if self.merge_mode is None:
return output + states
return [output] + states
return output
後記
與前一篇文章相比,本文顯得有些粗糙。不過這也是意料之中的事情:LSTM的原理並不會因爲它是TF1還是TF2發生變化,因此實現也不會有太大的變化,變的只會是類的組織方式。料想下一篇討論PyTorch的文章,更多也會集中在結構設計上,畢竟具體實現已經在前一篇文章裏描述得差不多了