利用fasttext對短信內容進行分類

一.樣本說明:

  • 1405506條記錄,其中逾期的爲486996條記錄,非逾期的爲486996條
  • 包含兩個字段 tag (標識是否逾期) ,message(短信內容)
  • 實際訓練樣本(non_overdue: 641065, overdue: 340783) 
  • 實際測試樣本(non_overdue: 274660, overdue: 146132)
  • 目標:根據短信內容,預測類別是否逾期

二.數據預處理:

  • 收集停用詞庫,這裏( https://github.com/goto456/stopwords)包含了哈工大停用詞、中文停用詞、四川大學智能實驗室停用詞、百度停用詞;但似乎也不是很全,呵呵!
  • 生成訓練集合和測試集合:我採用了python中random.shuffle的函數對數據重新排序,然後選取前70%的作爲訓練集,後30%作爲測試集; 

三.具體代碼實現:

  • 導入相關包
import re
import jieba
import os
import fasttext
import random
  • 生成停用詞表
# 生成停用詞表
def create_stoplist(dir_path):
    # dir_path爲停用詞文件存儲的目錄路徑
    stoplist=[]
    for ele in os.listdir(dir_path):
        file=dir_path+ele
        with open(file,'rb') as f:
            stoplist.extend(f.readlines())
    return list(set(stoplist))
  • 讀取樣本數據並進行shuffle
file='/Users/hqh/Desktop/savedata/message_2cls_subsample'
data=open(file,'r').readlines()
random.shuffle(data)
  • 數據處理並轉爲成fasttext所需的輸入形式
tranform_data=[] 
message_dict={}
for ele in data:
    message=ele.split('^')[0].strip("\n")
    tag=ele.split('^')[1].strip("\n")
    if len(message)>0:  # 可能存在message爲空的數據,所以加這個限制條件
        message=re.sub('\【.*\】','',message).strip('\n')  #去除平臺名稱、空格
        message_wc=list(jieba.cut(message))     #採用jieba分詞
        message_wc=" ".join(list(set(message_wc)-set(stoplist)))  #去除停用詞
        label='__label__'+tag    # 整理fasttext的標籤格式,__label__後面緊跟着類
        line=label+' '+message_wc+'\n'
        message_dict[message_wc]=ele+'^'+tag  # 這裏保存下切詞,停用詞處理後的message與原始的message的關係
        tranform_data.append(line)
  • 測試、訓練集的劃分
n=len(tranform_data)
train_data=tranform_data[1:int(n*0.7)]
test_data=tranform_data[int(n*0.7):]

with open('/Users/hqh/Desktop/train','w') as f:
    for ele in train_data:
        f.write(ele)

with open('/Users/hqh/Desktop/test','w') as f:
    for ele in test_data:
        f.write(ele)
  • 建立模型
classifier = fasttext.supervised('/Users/hqh/Desktop/train', 'model',label_prefix='__label__')
result=classifier.test("/Users/hqh/Desktop/test")
print(result.precision)
print(result.recall)
print(result.nexamples)
  • 手動找出識別錯誤的case

errors=[]  # 用於記錄錯誤的案例

with open('/Users/hqh/Desktop/errors','w') as f:
    for ele in test_data:
        info=ele.strip("\n").split(" ")
        try:
            key=" ".join(info[1:])
            message=[key]  # 從1開始是因爲第0個爲label
            label=classifier.predict(message)[0][0] # 預測的標籤
            tag=ele.strip("\n").split(" ")[0].split("__label__")[1] # 實際的標籤
            if label != tag:
                f.write(message_dict[key]+"\n")
                errors.append(message_dict[key])
        except IndexError as e:
            print('error')

print(len(errors)/len(test_data))
  • 具體結果

準確率和召回率都是 0.9997219548711368
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章