利用fasttext对短信内容进行分类

一.样本说明:

  • 1405506条记录,其中逾期的为486996条记录,非逾期的为486996条
  • 包含两个字段 tag (标识是否逾期) ,message(短信内容)
  • 实际训练样本(non_overdue: 641065, overdue: 340783) 
  • 实际测试样本(non_overdue: 274660, overdue: 146132)
  • 目标:根据短信内容,预测类别是否逾期

二.数据预处理:

  • 收集停用词库,这里( https://github.com/goto456/stopwords)包含了哈工大停用词、中文停用词、四川大学智能实验室停用词、百度停用词;但似乎也不是很全,呵呵!
  • 生成训练集合和测试集合:我采用了python中random.shuffle的函数对数据重新排序,然后选取前70%的作为训练集,后30%作为测试集; 

三.具体代码实现:

  • 导入相关包
import re
import jieba
import os
import fasttext
import random
  • 生成停用词表
# 生成停用词表
def create_stoplist(dir_path):
    # dir_path为停用词文件存储的目录路径
    stoplist=[]
    for ele in os.listdir(dir_path):
        file=dir_path+ele
        with open(file,'rb') as f:
            stoplist.extend(f.readlines())
    return list(set(stoplist))
  • 读取样本数据并进行shuffle
file='/Users/hqh/Desktop/savedata/message_2cls_subsample'
data=open(file,'r').readlines()
random.shuffle(data)
  • 数据处理并转为成fasttext所需的输入形式
tranform_data=[] 
message_dict={}
for ele in data:
    message=ele.split('^')[0].strip("\n")
    tag=ele.split('^')[1].strip("\n")
    if len(message)>0:  # 可能存在message为空的数据,所以加这个限制条件
        message=re.sub('\【.*\】','',message).strip('\n')  #去除平台名称、空格
        message_wc=list(jieba.cut(message))     #采用jieba分词
        message_wc=" ".join(list(set(message_wc)-set(stoplist)))  #去除停用词
        label='__label__'+tag    # 整理fasttext的标签格式,__label__后面紧跟着类
        line=label+' '+message_wc+'\n'
        message_dict[message_wc]=ele+'^'+tag  # 这里保存下切词,停用词处理后的message与原始的message的关系
        tranform_data.append(line)
  • 测试、训练集的划分
n=len(tranform_data)
train_data=tranform_data[1:int(n*0.7)]
test_data=tranform_data[int(n*0.7):]

with open('/Users/hqh/Desktop/train','w') as f:
    for ele in train_data:
        f.write(ele)

with open('/Users/hqh/Desktop/test','w') as f:
    for ele in test_data:
        f.write(ele)
  • 建立模型
classifier = fasttext.supervised('/Users/hqh/Desktop/train', 'model',label_prefix='__label__')
result=classifier.test("/Users/hqh/Desktop/test")
print(result.precision)
print(result.recall)
print(result.nexamples)
  • 手动找出识别错误的case

errors=[]  # 用于记录错误的案例

with open('/Users/hqh/Desktop/errors','w') as f:
    for ele in test_data:
        info=ele.strip("\n").split(" ")
        try:
            key=" ".join(info[1:])
            message=[key]  # 从1开始是因为第0个为label
            label=classifier.predict(message)[0][0] # 预测的标签
            tag=ele.strip("\n").split(" ")[0].split("__label__")[1] # 实际的标签
            if label != tag:
                f.write(message_dict[key]+"\n")
                errors.append(message_dict[key])
        except IndexError as e:
            print('error')

print(len(errors)/len(test_data))
  • 具体结果

准确率和召回率都是 0.9997219548711368
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章