功能
- 輸出影評主題;
- 輸出每份評論在各個主題上的權重分佈。
工具
引言
- 在機器學習中,LDA是兩個常用模型的簡稱:線性判別分析(Linear Discriminant Analysis)和隱含狄利克雷分佈(Latent Dirichlet allocation),本篇指的是後者。具體來說,LDA可以解決這樣的問題:如我現在有一批針對“大聖歸來”評論的文本,我想知道大家都在說些什麼,以及每個人在說些什麼。
- spark裏LDA函數的輸入是文本向量化的結果,LDA有兩個輸出:
- 每個主題的主題詞、每個主題詞對此主題的貢獻程度(權重)①;
- 每篇文本在各個主題上的權重分佈 ②。
- 那麼LDA是如何由文本得到主題詞及每篇文檔的主題分佈呢,我們令這批文本一共有3個主題,每個主題用6個詞表示,即每個主題都是6個主題詞。是這樣:
- 隨機初始化:,首先對當前所有文檔中的所有詞都隨機賦予一個主題號(0,1,2),然後統計每個主題下出現每個詞的數量(按照數量從大到小排序,排在前6位的即爲該主題下的主題詞)以及每個文檔下出現各個主題的數量(這就是前面說到的②);
- 迭代:按照Gibbs採樣規則,對每個詞重新賦予主題號,統計主題下出現的詞數量及每個文檔下出現的主題數量;
- 不停的迭代,直到統計的數量不變或者變化較小,停止迭代。
- Spark包含rdd和dataframe兩個接口(機器學習包對應mllib和ml),本文采用的是dataframe接口。
數據集
最優調參效果
- 迭代次數: maxIter=65
- 主題數: k=6
- 優化方法:online
- Alpha:設爲默認值
調參過程
- online,確定迭代次數
- 優化方法爲online下,畫出評價指標(logLikelihood,logPerplexity)和迭代次數的二維圖,其中log likelihood,越大越好,Perplexity評估,越小越好;由下圖可知,最優迭代次數大概在60到70之間,我們這裏令最優迭代次數爲65。
- online,迭代次數爲65,確定主題數
- 優化方法爲online,迭代次數爲65的前提下,將主題數目從2設到9,主觀觀察結果,發現主題數目太少信息提取不全,太多主題分散,主觀觀察後最終定爲6個主題。
- 主題數爲6,online,迭代65次,alpha設爲默認值,即0.16
- 主題數爲6,online,迭代65次,alpha設爲2
- 使用online的過程中,出現了主題非常集中,各個文檔對應的主題分佈也不鮮明,原因是alpha>1,alpha值設錯了,如下所示:
- EM,迭代65次,主題數爲6,確定alpha值
- 這裏沒有測試針對EM的最優迭代次數,設爲65,主題數設爲6,在這種情況下,alpha設置方式要參考以下三點:
- alpha必須>1.0,一般設置爲:(50/k)+1,k爲主題數;
- 評價指標(ogLikelihood,logPerplexity)和alpha的關係圖選擇合適的aplha值;
- 如果alpha設置的過大,各個文檔對應的主題分佈就不鮮明,此時要調小alpha
- 畫出評價指標(logLikelihood,logPerplexity)和alpha的二維圖,參考下圖,alpha可取13,參考公式,alpha可取9.3,然而經測試,alpha=13,9.3,7,5.5時,文檔的主題分佈均不鮮明;當alpha取1.1時,有稍微明顯的主題分佈,不過也有可能是迭代次數設置的不對。
調參規則總結
- 迭代次數: 結合logLikelihood、logPerplexity確定
- 主題數: 太少信息提取不全,太多信息分散,多試幾次
- 優化方法: online、EM
- Alpha
- online: alpha取默認值即可(1.0/k),取值要小於1小於等於0
- 注意:如果使用online的過程中,出現了主題非常集中,各個文檔對應的主題分佈也不鮮明,原因是alpha>1。
- EM: alpha必須>1.0;默認爲:(50/k)+1;根據評價指標(logLikelihood,logPerplexity)和alpha的關係圖選擇
- 注意:如果各個文檔對應的主題分佈不鮮明,此時要調小alpha值。
pyspark腳本
"""
@author:
@contact:
@time:
"""
from __future__ import print_function
from pyspark.sql import SparkSession
import os,ConfigParser,sys
reload(sys)
sys.setdefaultencoding("utf-8")
def configfileParameter(b):
pwd = sys.path[0]
path = os.path.abspath(os.path.join(pwd, os.pardir, os.pardir))
os.chdir(path)
cf = ConfigParser.ConfigParser()
cf.read("/con/configfile.conf")
SPARK_HOME = cf.get("SPARK_HOME", "SPARK_HOME")
return SPARK_HOME
os.environ['SPARK_HOME'] ="/lib/spark"
spark = SparkSession.builder.appName("etl").getOrCreate()
sc = spark.sparkContext
stopwords = sc.textFile("hdfs://stopwords.txt").collect()
def stopword(strArr):
stop_strArr = []
for i in strArr:
if len(i)> 1:
if i.isdigit()!=True:
if i not in stopwords:
stop_strArr.append(i)
return stop_strArr
"""
@author:
@contact:
@file:
@time:
"""
from __future__ import print_function
from pyspark.sql import functions as F
import sys,os,time,jieba,assistFuntion
reload(sys)
sys.setdefaultencoding("utf-8")
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.ml.clustering import LDA
from pyspark.ml.feature import CountVectorizer,IDF
from pyspark.sql.functions import split, explode
from pyspark.sql.window import Window, WindowSpec
print("運行開始時間:" + str(time.localtime(time.time()).tm_hour) + ":" + str(time.localtime(time.time()).tm_min) + "")
SPARK_HOME=assistFuntion.configfileParameter(1)
os.environ['SPARK_HOME'] = SPARK_HOME
spark = SparkSession.builder.appName("lda_test").getOrCreate()
sc = spark.sparkContext
lines = sc.textFile("hdfs://data.txt")
parts = lines.map(lambda l: l.split(" "))
textRdd = parts.map(lambda p: Row(da=p[0], text=p[1]))
textDf = spark.createDataFrame(textRdd)
textDf.createOrReplaceTempView("textDf")
sqlDF = spark.sql("select a,concat_ws(' ', collect_set(text)) as text_group from textDf group by a")
sqlDF.cache()
fenciDict = sc.textFile("hdfs://fenciDict.txt").collect()
for line in fenciDict:
jieba.add_word(line)
rdd= sqlDF.rdd.map(lambda x: (x.a, x.text_group)).map(lambda x: Row(a=x[0], text=",".join(jieba.cut(x[1]))))
rdd=rdd.map(lambda x: Row(a=x[1], text=x[0].split(",")))
preDf = rdd.map(lambda x: Row(a=x[1], text=etl.stopword(x[0]))).toDF()
preDf.cache()
cv = CountVectorizer(inputCol="text", outputCol="rawFeatures",vocabSize=2000)
cvModel = cv.fit(preDf)
cvResult = cvModel.transform(preDf)
idf = IDF(inputCol="rawFeatures", outputCol="features")
idfModel = idf.fit(cvResult)
tfidfResult = idfModel.transform(cvResult)
tfidfResult.cache()
voc = cvModel.vocabulary
L = range(0, 2000)
nvs = zip(L, voc)
nvDict = dict((id, word) for id, word in nvs)
def Index_toword(i):
word = nvDict[i]
return word
def intarr_index(intArr):
StrArr = []
for i in intArr:
StrArr.append(Index_toword(i))
return StrArr
def intArr2StrArr(intArr):
StrArr = []
for i in intArr:
StrArr.append(str(round(i, 4)))
return StrArr
lda = LDA(k=4, maxIter=80)
model = lda.fit(tfidfResult.select("a", "features"))
topics = model.describeTopics(6)
dfTopics = topics.rdd.map(lambda x: Row(topicId=x[0], termIndices=",".join(intarr_index(x[1])),termWeights=x[2])).toDF()
dfTopics=dfTopics.select(dfTopics['termIndices'], dfTopics['topicId'] + 1)
print("輸出主題詞,主題詞對應的權重分佈")
dfTopics.show(truncate=False)
transformed = model.transform(tfidfResult.select("a", "features"))
transformedrdd = transformed.rdd.map(lambda x: Row(a=x[0], topicDistribution=",".join(intArr2StrArr(x[2]))))
transformed = spark.createDataFrame(transformedrdd)
transformed_split=transformed.withColumn('topicDistribution', explode(split('topicDistribution', ',')))
transformed_split.cache()
transformed_split = transformed_split.select("a","topicDistribution", F.row_number().over(Window.partitionBy("a").orderBy("a")).alias("(topicId + 1)"))
transformed_split.cache()
w = Window.partitionBy('a')
DF=transformed_split.withColumn('maxtopicDistribution', F.max('topicDistribution').over(w))\
.where(F.col('topicDistribution') == F.col('maxtopicDistribution'))\
.drop('maxtopicDistribution')
DF.cache()
print("統計每個類別下的文本條數")
DF.groupBy("(topicId + 1)" ).count().show()
tagDf = sqlDF.join(DF,"a", "inner").select("(topicId + 1)",sqlDF.a,"text_group")
tagFinalDf = tagDf.join(dfTopics,"(topicId + 1)", "inner").select("termIndices","a","text_group")
tagFinalDf.show(100,truncate=False)
'''
迭代次數
根據評價指標:logLikelihood,logPerplexity判斷迭代次數
log likelihood,越大越好;
Perplexity評估,越小越好;
'''
'''
主題數目
迭代次數設爲65的前提下查看合適的主題個數;
主題數目,太少信息提取不全,太多主題分散;
'''
spark.stop()
print("運行結束時間:" + str(time.localtime(time.time()).tm_hour) + ":" + str(time.localtime(time.time()).tm_min) + "")
spark-submit --master yarn --jars etl.py --executor-memory 20G --total-executor-cores 12 ldaTest.py >>/test_lda_$(date +\%Y\%m\%d).log 2>&1 &