seq2seq的實現方式(1)

應用場景

seq2seq是自然語言處理應用中的常用模型,一般的機器翻譯,文本摘要,對話生成(雖然之前實現過基於語言模型+關鍵詞的生成方式,但這纔是正道),文本摘要等任務。更高級的模型也是從基礎的模型進行迭代的模型架構相對統一。

其具體的模型原理就不講了,有很多博客已經有很好的說明,在這裏只是趁着週末更新一下seq2seq在機器翻譯方面的實驗,更新上來供同行們參考。

嗯,seq2seq有幾種模式:
(1)最簡單的一種是Encoder的隱層向量複製後直接作爲decoder的輸入,也就是decoder對不需要序列輸入。
(2)在一個是Encoder的隱層向量作爲decoder的初始化,並且decoder的有輸入序列,並且和輸出序列錯位,用於啓發。
(3)就是把(1)和(2)結合起來,即要參考Decoder輸入序列,又要參考Encoder的最後的隱層向量,爲啓發獲取更多的信息。
(4)因爲在翻譯每一個詞的時候,輸入端各個詞的貢獻其實是不一樣的,所以用Encoder的最後隱層沒有多樣性,所以改用attention替換(3)中的Encoder隱層向量。
幾種方式一脈相承,逐步深化。

這裏實現了(2)的方法。

    def build_model(self):
        
        encoder_input = layers.Input(shape=(self.input_seq_len,))
        encoder_embeding = layers.Embedding(input_dim=len(self.en_word_id_dict),
                                            output_dim=self.encode_embeding_len,
                                            mask_zero=True
                                            )(encoder_input)
        encoder_lstm, state_h, state_c = layers.LSTM(units=self.encode_embeding_len,
                                                     return_state=True)(encoder_embeding)

        encoder_state = [state_h, state_c]

        decoder_input = layers.Input(shape=(self.output_seq_len,))
        decoder_embeding = layers.Embedding(input_dim=len(self.ch_word_id_dict),
                                            output_dim=self.decode_embeding_len,
                                            mask_zero=True
                                            )(decoder_input)
        decoder_lstm, _, _ = layers.LSTM(units=self.encode_embeding_len,
                                         return_state=True,
                                         return_sequences=True)(decoder_embeding, initial_state=encoder_state)
        decoder_out = layers.Dense(len(self.ch_word_id_dict), activation="softmax")(decoder_lstm)

        model = Model([encoder_input, decoder_input], decoder_out)
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
        # model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')
        model.summary()
        return model


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章