一.样本说明:
- 共1405506条记录,其中逾期的为486996条记录,非逾期的为486996条
- 包含两个字段 tag (标识是否逾期) ,message(短信内容)
- 实际训练样本(non_overdue: 641065, overdue: 340783)
- 实际测试样本(non_overdue: 274660, overdue: 146132)
- 目标:根据短信内容,预测类别是否逾期
二.数据预处理:
- 收集停用词库,这里( https://github.com/goto456/stopwords)包含了哈工大停用词、中文停用词、四川大学智能实验室停用词、百度停用词;但似乎也不是很全,呵呵!
- 生成训练集合和测试集合:我采用了python中random.shuffle的函数对数据重新排序,然后选取前70%的作为训练集,后30%作为测试集;
三.具体代码实现:
- 导入相关包
import re
import jieba
import os
import fasttext
import random
- 生成停用词表
# 生成停用词表
def create_stoplist(dir_path):
# dir_path为停用词文件存储的目录路径
stoplist=[]
for ele in os.listdir(dir_path):
file=dir_path+ele
with open(file,'rb') as f:
stoplist.extend(f.readlines())
return list(set(stoplist))
- 读取样本数据并进行shuffle
file='/Users/hqh/Desktop/savedata/message_2cls_subsample'
data=open(file,'r').readlines()
random.shuffle(data)
- 数据处理并转为成fasttext所需的输入形式
tranform_data=[]
message_dict={}
for ele in data:
message=ele.split('^')[0].strip("\n")
tag=ele.split('^')[1].strip("\n")
if len(message)>0: # 可能存在message为空的数据,所以加这个限制条件
message=re.sub('\【.*\】','',message).strip('\n') #去除平台名称、空格
message_wc=list(jieba.cut(message)) #采用jieba分词
message_wc=" ".join(list(set(message_wc)-set(stoplist))) #去除停用词
label='__label__'+tag # 整理fasttext的标签格式,__label__后面紧跟着类
line=label+' '+message_wc+'\n'
message_dict[message_wc]=ele+'^'+tag # 这里保存下切词,停用词处理后的message与原始的message的关系
tranform_data.append(line)
- 测试、训练集的划分
n=len(tranform_data)
train_data=tranform_data[1:int(n*0.7)]
test_data=tranform_data[int(n*0.7):]
with open('/Users/hqh/Desktop/train','w') as f:
for ele in train_data:
f.write(ele)
with open('/Users/hqh/Desktop/test','w') as f:
for ele in test_data:
f.write(ele)
- 建立模型
classifier = fasttext.supervised('/Users/hqh/Desktop/train', 'model',label_prefix='__label__')
result=classifier.test("/Users/hqh/Desktop/test")
print(result.precision)
print(result.recall)
print(result.nexamples)
- 手动找出识别错误的case
errors=[] # 用于记录错误的案例
with open('/Users/hqh/Desktop/errors','w') as f:
for ele in test_data:
info=ele.strip("\n").split(" ")
try:
key=" ".join(info[1:])
message=[key] # 从1开始是因为第0个为label
label=classifier.predict(message)[0][0] # 预测的标签
tag=ele.strip("\n").split(" ")[0].split("__label__")[1] # 实际的标签
if label != tag:
f.write(message_dict[key]+"\n")
errors.append(message_dict[key])
except IndexError as e:
print('error')
print(len(errors)/len(test_data))
- 具体结果
准确率和召回率都是 0.9997219548711368