Tensorflow RNN源碼理解

一、閱讀源碼

這個是TensorflowRNN源碼,官方註釋解釋的比較清楚:

 

RNNCell是一個抽象類,我們看下下它的屬性:

 

我們可以發現這裏用到的是Python內置的@property裝飾器就是負責把一個方法變成屬性調用的,很像C#中的屬性、字段的那種概念。State_sizeOutput_size規定了隱層的大小和輸出張量的大小。

 

下面是重要的__call__方法,有點像USRP中的work()或者general_work()的功能。這裏我們注意到輸入的參數有Inputs,State,這裏其實就是指輸入層和隱層了。但是這裏有規定Inputs的格式爲(batch_size,input_size,State的格式爲(batch_zie,state_size,這很容易理解,因爲我們進行訓練數據會分成很多batch。與普通的神經網絡結構一樣,輸入層、隱層、輸出層的size並沒有關係,視應用場景而定。

還有一個總要的方法是初始化方法:

BasicRNNCellGRUCellBasicLSTMCellLSTMCell都是繼承於LayerRNNCell,而LayerRNNCell繼承於上面講的抽象類RNNCell,這就是TensorflowRNN的繼承關係。

這裏不做過多介紹,但是有意思的一點是這裏:

 

其實BasicRNNCell的輸出、隱層狀態是一樣的。而BasicLSTMCell的隱層狀態和輸出是不一樣的。New_stae = LSTMStateTuple(new_c,new_h)

 

同樣RNN的隱層也可以構建多層MultiRNNCell

根據源碼可知:

 

隱層狀態的返回值是元組(tuple)類型

最重要的一個類Dynamic_rnn(batch_size,time_steps,input_size),參數很好理解,但是需要強調的是State是最後一步的隱藏狀態,形狀是(batch_size,cell.State_size),time_steps是調用RNNCell抽象類中__call__()函數的次數,Output是所有steps的輸出。Time_major=True的情況下將Output的格式中batch_szietime_steps位置交換。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章