一.樣本說明:
- 共1405506條記錄,其中逾期的爲486996條記錄,非逾期的爲486996條
- 包含兩個字段 tag (標識是否逾期) ,message(短信內容)
- 實際訓練樣本(non_overdue: 641065, overdue: 340783)
- 實際測試樣本(non_overdue: 274660, overdue: 146132)
- 目標:根據短信內容,預測類別是否逾期
二.數據預處理:
- 收集停用詞庫,這裏( https://github.com/goto456/stopwords)包含了哈工大停用詞、中文停用詞、四川大學智能實驗室停用詞、百度停用詞;但似乎也不是很全,呵呵!
- 生成訓練集合和測試集合:我採用了python中random.shuffle的函數對數據重新排序,然後選取前70%的作爲訓練集,後30%作爲測試集;
三.具體代碼實現:
- 導入相關包
import re
import jieba
import os
import fasttext
import random
- 生成停用詞表
# 生成停用詞表
def create_stoplist(dir_path):
# dir_path爲停用詞文件存儲的目錄路徑
stoplist=[]
for ele in os.listdir(dir_path):
file=dir_path+ele
with open(file,'rb') as f:
stoplist.extend(f.readlines())
return list(set(stoplist))
- 讀取樣本數據並進行shuffle
file='/Users/hqh/Desktop/savedata/message_2cls_subsample'
data=open(file,'r').readlines()
random.shuffle(data)
- 數據處理並轉爲成fasttext所需的輸入形式
tranform_data=[]
message_dict={}
for ele in data:
message=ele.split('^')[0].strip("\n")
tag=ele.split('^')[1].strip("\n")
if len(message)>0: # 可能存在message爲空的數據,所以加這個限制條件
message=re.sub('\【.*\】','',message).strip('\n') #去除平臺名稱、空格
message_wc=list(jieba.cut(message)) #採用jieba分詞
message_wc=" ".join(list(set(message_wc)-set(stoplist))) #去除停用詞
label='__label__'+tag # 整理fasttext的標籤格式,__label__後面緊跟着類
line=label+' '+message_wc+'\n'
message_dict[message_wc]=ele+'^'+tag # 這裏保存下切詞,停用詞處理後的message與原始的message的關係
tranform_data.append(line)
- 測試、訓練集的劃分
n=len(tranform_data)
train_data=tranform_data[1:int(n*0.7)]
test_data=tranform_data[int(n*0.7):]
with open('/Users/hqh/Desktop/train','w') as f:
for ele in train_data:
f.write(ele)
with open('/Users/hqh/Desktop/test','w') as f:
for ele in test_data:
f.write(ele)
- 建立模型
classifier = fasttext.supervised('/Users/hqh/Desktop/train', 'model',label_prefix='__label__')
result=classifier.test("/Users/hqh/Desktop/test")
print(result.precision)
print(result.recall)
print(result.nexamples)
- 手動找出識別錯誤的case
errors=[] # 用於記錄錯誤的案例
with open('/Users/hqh/Desktop/errors','w') as f:
for ele in test_data:
info=ele.strip("\n").split(" ")
try:
key=" ".join(info[1:])
message=[key] # 從1開始是因爲第0個爲label
label=classifier.predict(message)[0][0] # 預測的標籤
tag=ele.strip("\n").split(" ")[0].split("__label__")[1] # 實際的標籤
if label != tag:
f.write(message_dict[key]+"\n")
errors.append(message_dict[key])
except IndexError as e:
print('error')
print(len(errors)/len(test_data))
- 具體結果
準確率和召回率都是 0.9997219548711368