FlyAI搜狗新聞文本分類項目
1、項目簡介
搜狗新聞文本分類項目是NLP的入門項目,本文主要介紹使用keras框架通過構建CNN+BiGRU網絡實現在搜狗新聞文本數據集上91+的準確率。
2、數據集來源
該數據集來自若干新聞站點2012年6月—7月期間國內,國際,體育,社會,娛樂等18個頻道的新聞數據。根據新聞正文內容分析新聞的類別數據集官網鏈接: The SogouTCE. 該數據集樣例格式如下所示:
text,label “福田歐曼·歐康杯”第六屆全國卡車大賽廈門分站賽卡車寶貝進行才藝表演,汽車
在 FlyAI競賽平臺上 提供了超詳細的參考代碼,我們可以通過參加搜狗新聞文本分類預測練習賽進行進一步學習和優化。下面是準確率91+的代碼實現其主要部分代碼實現如下:
3、代碼實現
3.1、算法流程及實現
算法流程主要分爲以下四個部分進行介紹:
-
數據加載
-
構建網絡
-
模型訓練
數據加載
在FlyAI的項目中封裝了Dataset類,可以實現對數據的一些基本操作,比如加載批量訓練數據next_train_batch()和校驗數據next_validation_batch()、獲取全量數據get_all_data()、獲取訓練集數據量get_train_length() 和獲取校驗集數據量get_validation_length()等。具體使用方法如下:
# 引入Dataset類 from flyai.dataset import Dataset #創建Dataset類的實例 dataset = Dataset(epochs=10, batch=128) # dataset.get_step()返回訓練總步長 # 加載processor.py中處理好的數據,一次性加載所有的train和val數據 x_train, y_train, x_val, y_val = dataset.get_all_processor_data() print('Load data done!')
對每條新聞數據的讀取和處理是在processor.py文件中完成。具體實現如下:
from flyai.processor.base import Base import re import jieba import numpy as np from create_dict import load_dict, load_label_dict MAX_LEN = 128 # 最大詞數,分析數據集後給出的合理值 class Processor(Base): def __init__(self): super(Processor, self).__init__() # 加載字典文件 self.w2idx, self.idx2w = load_dict() self.l2idx, self.idx2l = load_label_dict() # 該參數[text]需要與app.yaml的Model的input-->columns->name 一一對應 def input_x(self, text): ''' 參數爲csv中作爲輸入x的一條數據,該方法會被Dataset多次調用 ''' # 對數據進行清洗 text_line = re.sub(r',{1,}', ',', re.sub(u"([^\u4e00-\u9fa5\u0030-\u0039\u0041-\u005a\u0061-\u007a“”.?!(()){}【\[\]】:~,。…\"])",",", text)) # print('text_line', text_line) terms = jieba.cut(text_line, cut_all=False) truncate_terms = [] for term in terms: truncate_terms.append(term) # 如果text詞數超過MAX_LEN提前跳出循環 if len(truncate_terms) >= MAX_LEN: break index_list = [self.w2idx[term] if term in self.w2idx else self.w2idx['_UNK_'] for term in truncate_terms] if len(index_list) < MAX_LEN: index_list = index_list+[self.w2idx['_PAD_']]*(MAX_LEN-len(index_list)) return index_list # 該參數[label]需要與app.yaml的Model的output-->columns->name 一一對應 def input_y(self, label): ''' 參數爲csv中作爲輸入y的一條數據,該方法會被Dataset多次調用 ''' y = np.zeros(len(self.l2idx)) if label in self.l2idx: label_index = self.l2idx[label] y[label_index] = 1 return y def output_y(self, data): ''' 驗證時使用,把模型輸出的y轉爲對應的結果 ''' out_y = np.argmax(data) if out_y in self.idx2l: out_y = self.idx2l[out_y] else: out_y = '未知標籤' return out_y
構建網絡
由於是搜狗新聞文本類數據,這裏我們可以使用一維卷積Conv1D + BiGRU來構建網絡,網絡結構如下所示:
# ----------------------------構建網絡---------------------------- # # 數據輸入格式 input_x = Input(shape=(MAX_LEN,), dtype='int32', name='input_x') # embedding層 embedden_seq = Embedding(input_dim=word_vocab, output_dim=embed_dim, input_length=MAX_LEN, name='embed')(input_x) # 第一個卷積層,由於要處理文本數據,這裏我們使用一維卷積Conv1D,conv_dim個卷積核,大小5x1,卷積模式SAME, # 激活函數relu, 最終輸出張量的shape和輸入的一致 conv1 = Conv1D(filters=conv_dim, kernel_size=5, padding='same', activation='relu')(embedden_seq) # 隨機斷開二分之一的網絡連接,防止過擬合 drop1 = SpatialDropout1D(0.5)(conv1) # 壓扁不影響batch_size大小 fl1 = Flatten()(drop1) # 批標準化,對數據進行標準化處理 bn1 = BatchNormalization()(embedden_seq) # 雙向GRU網絡,對數據進行前向後向化處理,更好的獲取上下文信息 bGRU1 = Bidirectional(GRU(rnn_size, activation='selu', return_sequences=True, implementation=1), merge_mode='concat')(bn1) # 隨機斷開四分之一的網絡連接,防止過擬合 drop2 = SpatialDropout1D(0.5)(bGRU1) # 壓扁不影響batch_size大小 fl2 = Flatten()(drop2) # 將CNN網絡和GRU網絡輸出連接起來 cont1 = concatenate([fl1, fl2]) # 添加隱藏層神經元的數量和激活函數 dense1 = Dense(256, activation='relu')(cont1) # 最終經激活函數softmax輸出對應位置的predict概率 predict = Dense(class_num, activation='softmax')(dense1) # 傳入數據 k_model = keras.Model(input_x, predict) # 打印網絡詳情 k_model.summary() # 執行模型編譯,傳入優化器名稱並修改相應參數 k_model.compile(optimizer=keras.optimizers.adam(lr=0.001, decay=1e-6), loss='categorical_crossentropy', metrics=['accuracy', ])
運行summary()方法後輸出的網絡結構如下圖:
模型訓練
這裏我們設置了epoch爲5,batch爲128,採用adam優化器來訓練網絡,EarlyStopping可以加速調參過程。然後通過調用FlyAI提供的train_log方法可以在訓練過程中實時的看到訓練集和驗證集的準確率及損失變化曲線。
# 1. 定義early_stopping規則 early_stopping = EarlyStopping(monitor='val_acc', patience=2) # 2. 將數據fit到網絡,訓練模型,以BATCH個樣本爲一個batch進行迭代,EPOCHS爲要迭代的輪次數 k_model.fit(x_train, y_train, batch_size=args.BATCH, validation_data=(x_val, y_val), epochs=args.EPOCHS, verbose=1, callbacks=[early_stopping])
3.2、最終結果
通過使用自定義CNN網絡結構+雙向GRU網絡的方法,在epoch爲10,batch爲128的條件下使用adam優化器下不斷優化模型參數,使用early_stopping規則在model訓練達到early_stopping條件時提前終止訓練提高model優化效率,最終模型在測試集的準確率達到91+。 該項目的可運行代碼如下:完整代碼鏈接。
參考鏈接: