本文主要是對tensorflow中lstm模型中的c,h進行解析。rnn_cell_impl.py
1.關於RNN模型
在rnn_cell_impl.py的tensorflow源碼中,關於RNN部分實現的類主要是BasicRNNCell,
首先在build函數中,定義了兩個變量_kernel和_bias。
其中_num_untis表示RNN cell 的untis數目。
所以在call函數中,hidden_state的更新如下所示:
從上面中可以看出,RNN首先將input與上一個state連接,然後與在build函數中定義的_kernal變量點乘,最後加上偏置項。
2. 關於LSTM模型
主要看BasicLSTMCell這個類,在build函數中,定義了兩個參數_kernel與_bias
與RNN不同,參數_kernal與_bias的列都是_num_units的四倍,主要是因爲後面要分成四個部分,分別爲i,j,f,o。
因此,在call函數中,
在call函數中,i,j,f,o可以分別表示爲:
所以,在上面的圖中,最上面的橫線表示C,最小面的橫線表示h。
3. 關於GRU模型
在GRU模型中的build函數中,可以看到定義了四個參數:
因此,在call函數中,
從下面的圖中可以看出,zt爲u,r表示rt,
從tensorflow的源碼來看,上面的公式中ht的求解有問題,所以參考維基百科,得到下面的公式: