RNN在深度學習中佔據重要地位,我們常常調用tensorflow的包就可以完成RNN的構建與訓練,但通用的RNN並不總是能滿足我們的需求,若要改動,必先知其細。
也許你會說,我自己用for循環寫個rnn的實現不就好了嘛,當然可以啊。但內置的函數一般都比for循環快,用 while_loop 的好處是速度快效率高,因爲它是一個tf的內置運算,會構建入運算圖的,循環運行的時候不會再與python作交互。
下面我們根據源碼對RNN的實現一探究竟。在探究之前,先來說一下什麼叫動態RNN,我們都知道RNN全名是循環神經網絡,循環嘛,自然是動態的,通過循環的方式動態生成一個個的token,直至整句話生成完畢時停止。這裏用到的知識點是tf.while_loop(),如下
#舉例:循環變量a,b,c;f(.),g(.),h(.)是函數
tf.while_loop(
condition,
body,
loop_vars=(init_a,init_b,init_c)
)
def condition(unuesd_a,b,unuesd_c):
# 即使a,c變量用不到,也要寫在condition函數的參數中
return b>1 # 返回bool類型的值
def body(a,b,c):
next_a=f(a)
next_b=g(b)
next_c=h(c)
return next_a,next_b,next_c
這個函數的功能是:當condition函數return的結果爲True時,進入循環體body()進行計算,並返回更新後的變量的值;如果condition函數return的結果爲False,循環結束。
目錄
第一板斧 負責決定當前時間步的輸出(sample函數)和下一時間步的輸入(next_inputs函數)。
第三板斧負責模擬RNN在每個時間步的情況,並在合適的時刻(比如遇到eos或者達到指定的最大長度)停止。
1 tensorflow 版本
import tensorflow as tf
tf.__version__ # tensorflow版本爲1.12.0
2 動態RNN實現“三板斧”
如果定製自己需要的動態RNN,只需要修改三板斧中的對應函數,即可將自己的想法融入tf框架中,無需自己從0實現一個動態RNN,原因有二,一是方便,二是自己從0實現的不一定比tf的寫的好emm
第一板斧 負責決定當前時間步的輸出(sample函數)和下一時間步的輸入(next_inputs函數)。
helper = tf.contrib.seq2seq.TrainingHelper(inputs=input_embed,…) # input_embed是rnn輸入字符的embedding
上述函數在文件helper.py中,是專用於訓練時候的helper,除此之外,helper.py中還有適用於inference時候的helper,一起來看看源碼(下面爲源碼的重要部分截取,不是完整的helper.py文件,下同),關鍵是sample函數和next_inputs函數的實現。
# helper.py中所有的class,除了用於訓練的TrainingHelper,
# 還有一些用於推斷時候的helper,甚至可以自定義,即CustomHelper。
# 對於每個helper,關鍵在於sample函數和next_inputs函數的實現。
__all__ = [
"Helper",
"TrainingHelper",
"GreedyEmbeddingHelper",
"SampleEmbeddingHelper",
"CustomHelper",
"ScheduledEmbeddingTrainingHelper",
"ScheduledOutputTrainingHelper",
"InferenceHelper",
]
#訓練階段 以TrainingHelper爲例進行分析
class TrainingHelper(Helper):
def __init__(self, inputs, sequence_length, time_major=False, name=None):
initial部分的源碼不進行粘貼
def sample(self, time, outputs, name=None, **unused_kwargs):
# 採樣得到當前時間步的輸出token
with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
sample_ids = math_ops.cast(
math_ops.argmax(outputs, axis=-1), dtypes.int32) # 取概率最大的token作爲輸出
return sample_ids
def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
"""next_inputs_fn for TrainingHelper."""
with ops.name_scope(name, "TrainingHelperNextInputs",
[time, outputs, state]):
next_time = time + 1
finished = (next_time >= self._sequence_length)
all_finished = math_ops.reduce_all(finished)
def read_from_ta(inp):
return inp.read(next_time)
# 若rnn未finished,則取當前時間步輸出真值作爲下一步的輸入。因爲訓練階段是有標籤的
next_inputs = control_flow_ops.cond(
all_finished, lambda: self._zero_inputs,
lambda: nest.map_structure(read_from_ta, self._input_tas))
return (finished, next_inputs, state)
# 推斷階段的helper以GreedyEmbeddingHelper爲例進行分析
class GreedyEmbeddingHelper(Helper):
def sample(self, time, outputs, state, name=None):
"""sample for GreedyEmbeddingHelper."""
del time, state # unused by sample_fn
# Outputs are logits, use argmax to get the most probable id
if not isinstance(outputs, ops.Tensor):
raise TypeError("Expected outputs to be a single Tensor, got: %s" %
type(outputs))
sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
return sample_ids
def next_inputs(self, time, outputs, state, sample_ids, name=None):
"""next_inputs_fn for GreedyEmbeddingHelper."""
del time, outputs # unused by next_inputs_fn
finished = math_ops.equal(sample_ids, self._end_token)
all_finished = math_ops.reduce_all(finished)
# 因爲是推斷階段,所以把當前時間步的輸出的預測值作爲下一步的輸入。
# sample_ids是token id,所以用_embedding_fn函數得到其embedding後再作爲next_inputs
next_inputs = control_flow_ops.cond(
all_finished,
# If we're finished, the next_inputs value doesn't matter
lambda: self._start_inputs,
lambda: self._embedding_fn(sample_ids))
return (finished, next_inputs, state)
第二板斧 負責執行一個時間步(step函數),
調用cell得到該時間步的輸出概率,調用helper得到該時間步的輸出token id和下一步的輸入token的embedding。
decoder = tf.contrib.seq2seq.BasicDecoder(cell=rnn_cell, helper=helper,…)
上述函數在basic_decoder.py中,BasicDecoder類繼承於Decoder類(Decoder類在decoder.py文件中,和dynamic_decode函數在一個文件中),實現了Decoder類中的step函數。
其他的Decoder比如BeamSearchDecoder也繼承於Decoder類,實現了Decoder類中的step函數。
所以,如果想自己實現一個decoder的話,繼承Decoder類並實現step函數即可。
def step(self, time, inputs, state, name=None):
"""Perform a decoding step.
Args:
time: scalar `int32` tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
name: Name scope for any created operations.
Returns:
`(outputs, next_state, next_inputs, finished)`.
"""
with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
cell_outputs, cell_state = self._cell(inputs, state)
if self._output_layer is not None:
cell_outputs = self._output_layer(cell_outputs)
sample_ids = self._helper.sample(
time=time, outputs=cell_outputs, state=cell_state)
(finished, next_inputs, next_state) = self._helper.next_inputs(
time=time,
outputs=cell_outputs,
state=cell_state,
sample_ids=sample_ids)
outputs = BasicDecoderOutput(cell_outputs, sample_ids)
return (outputs, next_state, next_inputs, finished)
可以看出,step函數中調用了cell來得到當前時間步的輸出,這裏的cell是rnn_cell,定義了RNN的結構,所以瞭解cell的輸入與輸出是什麼很重要,這樣才能正確調用。如果想了解常用rnn_cell的結構,可以閱讀Tensorflow RNN結構 解讀
第三板斧負責模擬RNN在每個時間步的情況,並在合適的時刻(比如遇到eos或者達到指定的最大長度)停止。
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,…)
上述函數位於decoder.py文件中,dynamic_decode是一個loop(循環)來得到全部時間步的情況,每個時間步都調用decoder。
"""
condition是判斷是否停止的條件,body中會調用decoder.step()來得到相關信息。
loop_vars是在循環中不斷變化更新的變量,這些變量需要輸入到body函數中,
在body函數中計算更新並return,以作爲下一個循環body函數的輸入。
這裏res的內容其實就是body函數返回的內容,也就是loop_vars的值。
"""
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
finished, unused_sequence_lengths,):
return math_ops.logical_not(math_ops.reduce_all(finished))
res = control_flow_ops.while_loop(
condition,
body,
loop_vars=(
initial_time,
initial_outputs_ta,
initial_state,
initial_inputs,
initial_finished,
initial_sequence_lengths,
),
parallel_iterations=parallel_iterations,
maximum_iterations=maximum_iterations,
swap_memory=swap_memory)
所以,如果想定製自己需要的動態RNN,要想清楚loop_vars有哪些,然後寫到loop_vars中哦~
我們來分析下停止條件condition()函數,finished的shape是(batch_size,),math_ops.reduce_all()是Computes the "logical and" of elements across dimensions of a tensor,math_ops.logical_not()是邏輯非。若要循環停止,則math_ops.logical_not()爲False,即math_ops.reduce_all()爲True,即finished中每個元素都爲True。我們知道,循環結束的標誌是句子解碼結束,每個finished中第i個元素爲True,代表該batch中第i個句子已經解碼結束,所以,結論是,只有當batch中所有句子都解碼結束,纔會停止循環,即一個 batch 中的句子長度不相同時,得到的 dynamic_length 應該是某個 batch 中最長的一句的長度。