關於RNN相關模型-tensorflow源碼理解

本文主要是對tensorflow中lstm模型中的c,h進行解析。rnn_cell_impl.py

1.關於RNN模型

在rnn_cell_impl.py的tensorflow源碼中,關於RNN部分實現的類主要是BasicRNNCell,
首先在build函數中,定義了兩個變量_kernel和_bias。
這裏寫圖片描述
其中_num_untis表示RNN cell 的untis數目。
所以在call函數中,hidden_state的更新如下所示:
這裏寫圖片描述

從上面中可以看出,RNN首先將input與上一個state連接,然後與在build函數中定義的_kernal變量點乘,最後加上偏置項。

2. 關於LSTM模型

主要看BasicLSTMCell這個類,在build函數中,定義了兩個參數_kernel與_bias
這裏寫圖片描述
與RNN不同,參數_kernal與_bias的列都是_num_units的四倍,主要是因爲後面要分成四個部分,分別爲i,j,f,o。
因此,在call函數中,
這裏寫圖片描述
在call函數中,i,j,f,o可以分別表示爲:
這裏寫圖片描述

所以,在上面的圖中,最上面的橫線表示C,最小面的橫線表示h。

3. 關於GRU模型

在GRU模型中的build函數中,可以看到定義了四個參數:
這裏寫圖片描述
因此,在call函數中,

這裏寫圖片描述
從下面的圖中可以看出,zt爲u,r表示rt,
這裏寫圖片描述

從tensorflow的源碼來看,上面的公式中ht的求解有問題,所以參考維基百科,得到下面的公式:
這裏寫圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章