In the previous post, 使用gensim和sklearn搭建一個文本分類器(一):流程概述 (Building a text classifier with gensim and sklearn, part 1: workflow overview), I described a text classifier that vectorizes documents with LSI and then classifies them with a linear-kernel SVM. Following that outline, this post provides the concrete implementation. The earlier workflow has been consolidated into five main stages:
- Build the dictionary
- Generate tf-idf vectors
- Generate LSI vectors
- Train the classifier
- Classify new documents
The first four steps can be regarded as the training process of the classifier, while the fifth stage uses the trained parameters to classify new documents.
All steps live in the main block and are separated by comment markers for readability. Because the five stages are relatively independent, each stage saves its results to disk, so the following stage reads the earlier results directly from disk instead of re-running the previous code. To re-run a particular stage, simply delete its result files.
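The check-then-skip pattern each stage relies on looks roughly like the sketch below. This is only an illustration: run_stage and build_stage_result are placeholder names that do not appear in the actual code, and the real program saves gensim objects with their own save()/serialize() methods rather than pickle.

    import os
    import pickle as pkl

    def run_stage(result_path, build_stage_result):
        # If this stage's result is already on disk, load and reuse it;
        # otherwise compute it and cache it for the next run.
        if os.path.exists(result_path):
            with open(result_path, 'rb') as f:
                return pkl.load(f)
        result = build_stage_result()
        with open(result_path, 'wb') as f:
            pkl.dump(result, f)
        return result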
Note that a few paths must be set before running the program, the most important being path_doc_root and path_tmp. The former is the directory holding the text files; the latter holds the intermediate files generated while the program runs. In my case the former is
path_doc_root = '/media/multiangle/F/DataSet/THUCNews/THUCNewsTotal'
In other words, all documents should be placed under this folder grouped by category: each category is a sub-folder, and each category folder contains many documents, one document per txt file, as shown below. I use the THUCNews collection here.
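The expected layout of path_doc_root is roughly the following (category_A, category_B and the file names are placeholders; the real THUCNews folders are named after its news categories):

    THUCNewsTotal/
        category_A/
            0001.txt
            0002.txt
            ...
        category_B/
            0001.txt
            ...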
The full code is as follows:
from gensim import corpora, models
from scipy.sparse import csr_matrix
from sklearn import svm
import numpy as np
import os, re, time, logging
import jieba
import pickle as pkl

# logging.basicConfig(level=logging.WARNING,
#                     format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
#                     datefmt='%a, %d %b %Y %H:%M:%S',
#                     )
class loadFolders(object):   # iterator over the category sub-folders
    def __init__(self, par_path):
        self.par_path = par_path

    def __iter__(self):
        for file in os.listdir(self.par_path):
            file_abspath = os.path.join(self.par_path, file)
            if os.path.isdir(file_abspath):  # yield only folders
                yield file_abspath

class loadFiles(object):     # iterator over (category, document text) pairs
    def __init__(self, par_path):
        self.par_path = par_path

    def __iter__(self):
        folders = loadFolders(self.par_path)
        for folder in folders:               # first-level directories, one per category
            catg = folder.split(os.sep)[-1]
            for file in os.listdir(folder):  # files inside each category folder
                file_path = os.path.join(folder, file)
                if os.path.isfile(file_path):
                    this_file = open(file_path, 'rb')
                    content = this_file.read().decode('utf8')
                    this_file.close()
                    yield catg, content
def convert_doc_to_wordlist(str_doc, cut_all):
    sent_list = str_doc.split('\n')
    sent_list = map(rm_char, sent_list)   # strip unwanted characters such as \u3000
    word_2dlist = [rm_tokens(jieba.cut(part, cut_all=cut_all)) for part in sent_list]  # tokenize each sentence
    word_list = sum(word_2dlist, [])      # flatten the list of lists
    return word_list

def rm_tokens(words):  # remove stop words and pure digits
    words_list = list(words)
    stop_words = get_stop_words()   # note: this re-reads the stop word file on every call
    for i in reversed(range(len(words_list))):
        if words_list[i] in stop_words:   # drop stop words
            words_list.pop(i)
        elif words_list[i].isdigit():     # drop tokens that are pure digits
            words_list.pop(i)
    return words_list

def get_stop_words(path='/home/multiangle/coding/python/PyNLP/static/stop_words.txt'):
    # the stop word file is expected to contain one stop word per line
    with open(path, 'rb') as f:
        stop_words = f.read().decode('utf8').split('\n')
    return set(stop_words)

def rm_char(text):
    text = re.sub('\u3000', '', text)   # strip the full-width space character
    return text
def svm_classify(train_set, train_tag, test_set, test_tag):
    clf = svm.LinearSVC()
    clf_res = clf.fit(train_set, train_tag)
    train_pred = clf_res.predict(train_set)
    test_pred = clf_res.predict(test_set)

    train_err_num, train_err_ratio = checkPred(train_tag, train_pred)
    test_err_num, test_err_ratio = checkPred(test_tag, test_pred)

    print('=== Training finished, classification results ===')
    print('Training set error rate: {e}'.format(e=train_err_ratio))
    print('Test set error rate: {e}'.format(e=test_err_ratio))

    return clf_res

def checkPred(data_tag, data_pred):
    if len(data_tag) != len(data_pred):
        raise RuntimeError('The length of data tag and data pred should be the same')
    err_count = 0
    for i in range(len(data_tag)):
        if data_tag[i] != data_pred[i]:
            err_count += 1
    err_ratio = err_count / len(data_tag)
    return [err_count, err_ratio]
if __name__ == '__main__':
    path_doc_root = '/media/multiangle/F/DataSet/THUCNews/THUCNewsTotal'  # root directory holding the documents, grouped by category
    path_tmp = '/media/multiangle/F/DataSet/THUCNews/tmp'                 # directory for intermediate results
    path_dictionary = os.path.join(path_tmp, 'THUNews.dict')
    path_tmp_tfidf = os.path.join(path_tmp, 'tfidf_corpus')
    path_tmp_lsi = os.path.join(path_tmp, 'lsi_corpus')
    path_tmp_lsimodel = os.path.join(path_tmp, 'lsi_model.pkl')
    path_tmp_predictor = os.path.join(path_tmp, 'predictor.pkl')

    n = 10  # sampling rate: keep 1 document out of every n

    dictionary = None
    corpus_tfidf = None
    corpus_lsi = None
    lsi_model = None
    predictor = None
    if not os.path.exists(path_tmp):
        os.makedirs(path_tmp)
    # # ===================================================================
    # # # # Stage 1: scan the documents, build the dictionary, and drop low-frequency tokens
    # If no dictionary exists at the given path, build a new one; otherwise skip this stage
    if not os.path.exists(path_dictionary):
        print('=== No dictionary found, start scanning documents to build one ===')
        dictionary = corpora.Dictionary()
        files = loadFiles(path_doc_root)
        for i, msg in enumerate(files):
            if i % n == 0:
                catg = msg[0]
                file = msg[1]
                file = convert_doc_to_wordlist(file, cut_all=False)
                dictionary.add_documents([file])
                if int(i / n) % 1000 == 0:
                    print('{t} *** {i} \t docs have been processed'
                          .format(i=i, t=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())))
        # drop tokens that appear in too few documents
        small_freq_ids = [tokenid for tokenid, docfreq in dictionary.dfs.items() if docfreq < 5]
        dictionary.filter_tokens(small_freq_ids)
        dictionary.compactify()
        dictionary.save(path_dictionary)
        print('=== Dictionary has been built ===')
    else:
        print('=== Dictionary already exists, skipping this stage ===')
    # # ===================================================================
    # # # # Stage 2: convert the documents into tfidf vectors
    if not os.path.exists(path_tmp_tfidf):
        print('=== No tfidf folder found, start generating tfidf vectors ===')
        # If no tfidf corpus exists at the given path, build one; otherwise skip this stage
        if not dictionary:  # if stage 1 was skipped, load the dictionary from disk
            dictionary = corpora.Dictionary.load(path_dictionary)
        os.makedirs(path_tmp_tfidf)
        files = loadFiles(path_doc_root)
        tfidf_model = models.TfidfModel(dictionary=dictionary)
        corpus_tfidf = {}
        for i, msg in enumerate(files):
            if i % n == 0:
                catg = msg[0]
                file = msg[1]
                word_list = convert_doc_to_wordlist(file, cut_all=False)
                file_bow = dictionary.doc2bow(word_list)
                file_tfidf = tfidf_model[file_bow]
                tmp = corpus_tfidf.get(catg, [])
                tmp.append(file_tfidf)
                if len(tmp) == 1:   # first document of this category: register the list in the dict
                    corpus_tfidf[catg] = tmp
            if i % 10000 == 0:
                print('{i} files have been processed'.format(i=i))
        # save the tfidf intermediate results to disk
        catgs = list(corpus_tfidf.keys())
        for catg in catgs:
            corpora.MmCorpus.serialize('{f}{s}{c}.mm'.format(f=path_tmp_tfidf, s=os.sep, c=catg),
                                       corpus_tfidf.get(catg),
                                       id2word=dictionary
                                       )
            print('catg {c} has been transformed into tfidf vector'.format(c=catg))
        print('=== tfidf vectors have been generated ===')
    else:
        print('=== tfidf vectors already exist, skipping this stage ===')
    # # ===================================================================
    # # # # Stage 3: convert the tfidf vectors into lsi vectors
    if not os.path.exists(path_tmp_lsi):
        print('=== No lsi folder found, start generating lsi vectors ===')
        if not dictionary:
            dictionary = corpora.Dictionary.load(path_dictionary)
        if not corpus_tfidf:  # if stage 2 was skipped, load the tfidf corpus from disk
            print('--- No tfidf corpus in memory, loading it from disk ---')
            # collect all category names from the tfidf folder
            files = os.listdir(path_tmp_tfidf)
            catg_list = []
            for file in files:
                t = file.split('.')[0]
                if t not in catg_list:
                    catg_list.append(t)
            # load the corpus from disk
            corpus_tfidf = {}
            for catg in catg_list:
                path = '{f}{s}{c}.mm'.format(f=path_tmp_tfidf, s=os.sep, c=catg)
                corpus = corpora.MmCorpus(path)
                corpus_tfidf[catg] = corpus
            print('--- tfidf corpus loaded, start converting it into lsi vectors ---')

        # train the lsi model
        os.makedirs(path_tmp_lsi)
        corpus_tfidf_total = []
        catgs = list(corpus_tfidf.keys())
        for catg in catgs:
            tmp = corpus_tfidf.get(catg)
            corpus_tfidf_total += tmp
        lsi_model = models.LsiModel(corpus=corpus_tfidf_total, id2word=dictionary, num_topics=50)
        # save the lsi model to disk
        lsi_file = open(path_tmp_lsimodel, 'wb')
        pkl.dump(lsi_model, lsi_file)
        lsi_file.close()
        del corpus_tfidf_total  # the lsi model is ready, free the memory
        print('--- lsi model has been trained ---')

        # generate the lsi corpus and progressively drop the tfidf corpus
        corpus_lsi = {}
        for catg in catgs:
            corpu = [lsi_model[doc] for doc in corpus_tfidf.get(catg)]
            corpus_lsi[catg] = corpu
            corpus_tfidf.pop(catg)
            corpora.MmCorpus.serialize('{f}{s}{c}.mm'.format(f=path_tmp_lsi, s=os.sep, c=catg),
                                       corpu,
                                       id2word=dictionary)
        print('=== lsi vectors have been generated ===')
    else:
        print('=== lsi vectors already exist, skipping this stage ===')
    # # ===================================================================
    # # # # Stage 4: classification (train the classifier)
    if not os.path.exists(path_tmp_predictor):
        print('=== No predictor found, start the classification stage ===')
        if not corpus_lsi:  # if stage 3 was skipped, load the lsi corpus from disk
            print('--- No lsi corpus in memory, loading it from disk ---')
            files = os.listdir(path_tmp_lsi)
            catg_list = []
            for file in files:
                t = file.split('.')[0]
                if t not in catg_list:
                    catg_list.append(t)
            # load the corpus from disk
            corpus_lsi = {}
            for catg in catg_list:
                path = '{f}{s}{c}.mm'.format(f=path_tmp_lsi, s=os.sep, c=catg)
                corpus = corpora.MmCorpus(path)
                corpus_lsi[catg] = corpus
            print('--- lsi corpus loaded, start training the classifier ---')

        tag_list = []
        doc_num_list = []
        corpus_lsi_total = []
        catg_list = []
        files = os.listdir(path_tmp_lsi)
        for file in files:
            t = file.split('.')[0]
            if t not in catg_list:
                catg_list.append(t)
        for count, catg in enumerate(catg_list):
            tmp = corpus_lsi[catg]
            tag_list += [count] * len(tmp)   # each document's label is the index of its category
            doc_num_list.append(len(tmp))
            corpus_lsi_total += tmp
            corpus_lsi.pop(catg)

        # convert the gensim mm (sparse) representation into a numpy matrix
        data = []
        rows = []
        cols = []
        line_count = 0
        for line in corpus_lsi_total:
            for elem in line:
                rows.append(line_count)
                cols.append(elem[0])
                data.append(elem[1])
            line_count += 1
        lsi_matrix = csr_matrix((data, (rows, cols))).toarray()

        # build the training and test sets (roughly 80% / 20% split)
        rarray = np.random.random(size=line_count)
        train_set = []
        train_tag = []
        test_set = []
        test_tag = []
        for i in range(line_count):
            if rarray[i] < 0.8:
                train_set.append(lsi_matrix[i, :])
                train_tag.append(tag_list[i])
            else:
                test_set.append(lsi_matrix[i, :])
                test_tag.append(tag_list[i])

        # train the classifier
        predictor = svm_classify(train_set, train_tag, test_set, test_tag)
        x = open(path_tmp_predictor, 'wb')
        pkl.dump(predictor, x)
        x.close()
    else:
        print('=== Predictor already exists, skipping this stage ===')
    # # ===================================================================
    # # # # Stage 5: classify a new document
    if not dictionary:
        dictionary = corpora.Dictionary.load(path_dictionary)
    if not lsi_model:
        lsi_file = open(path_tmp_lsimodel, 'rb')
        lsi_model = pkl.load(lsi_file)
        lsi_file.close()
    if not predictor:
        x = open(path_tmp_predictor, 'rb')
        predictor = pkl.load(x)
        x.close()
    files = os.listdir(path_tmp_lsi)
    catg_list = []
    for file in files:
        t = file.split('.')[0]
        if t not in catg_list:
            catg_list.append(t)

    # sample document to classify (a Chinese news snippet, since the classifier was trained on Chinese text)
    demo_doc = """
這次大選讓兩黨的精英都摸不着頭腦。以媒體專家的傳統觀點來看,要選總統首先要避免失言,避免說出一些“offensive”的話。希拉里,羅姆尼,都是按這個方法操作的。羅姆尼上次的47%言論是在一個私人場合被偷錄下來的,不是他有意公開發表的。今年希拉里更是從來沒有召開過新聞發佈會。
川普這種肆無忌憚的發言方式,在傳統觀點看來等於自殺。
"""
    print("The original document is:")
    print(demo_doc)
    demo_doc = convert_doc_to_wordlist(demo_doc, cut_all=False)  # apply the same preprocessing used during training (tokenize + remove stop words)
    demo_bow = dictionary.doc2bow(demo_doc)
    tfidf_model = models.TfidfModel(dictionary=dictionary)
    demo_tfidf = tfidf_model[demo_bow]
    demo_lsi = lsi_model[demo_tfidf]
    data = []
    cols = []
    rows = []
    for item in demo_lsi:
        data.append(item[1])
        cols.append(item[0])
        rows.append(0)
    # pad to the full number of topics so the feature dimension matches the training data
    demo_matrix = csr_matrix((data, (rows, cols)), shape=(1, lsi_model.num_topics)).toarray()
    x = predictor.predict(demo_matrix)
    print('The predicted category is: {x}'.format(x=catg_list[x[0]]))