SparkML之聚類(二)高斯混合模型(GMMs)



1、閒聊

在講高斯混合模型,我們先拋開一切,來一些推導。推導前,假設你認可兩個統計學基礎的兩個定理

(1)大數定理(2)中心極限定理


聯合實際情況就是說,假如我們坐在廣州地鐵1號線的某個地方進行蹲點1天,記錄下地鐵全部女性的身高。這一天下來她們的身高的均值和方差。和我們第二天繼續第一天的工作得到的均值和方差是接近的。而且服從高斯分佈。

面可以證明每個點產生的概率值聯合起來爲什麼是一個“鐘形”曲線。也就是證明高斯的分佈函數形式。


有了這個關於高斯密度分佈函數,那麼我們在上面取一個x就可以得到一個概率。它的意義在於,我們某一天又在哪

裏蹲點。我們就可以說下一個女人身高爲x的概率是多少。但是這裏有兩個假設

(1)任何一個女性都有可能在這裏出現

(2)整個過程沒有女性死亡和出生

但是現在細分下去就會出現的問題是,身高和區域有關

也就是說在北京地鐵蹲點得到的身高比廣州地鐵蹲點得到身高會“普遍更高”。也就是出現我們在預測下一個的時候x,輸入x到高斯分佈函數函數中,得到的值會比真值“更低”或者“更高”的現象。這就說明一點高斯分佈是和羣體有關。

也可以得到一點的是:不同的羣體,高斯分佈的均值和方差不一樣。

就是根據這樣的思想,所以就可以利用高斯分佈來聚類。


2、高斯混合模型(Gaussian Mixture Models (GMMs))

高斯和密度函數估計是一種參數化模型,有SGM(Single Gaussian Model)和GMM(Gaussian Mixture Model)



3、EM算法

對於EM算法可以參考andrew NG 的 <The EM Algorithm>[1],下面來簡述EM算法步驟:


4、Matlab實現高斯混合模型分類

load gmm_data.txt
%gmm_data.txt爲spark源碼 data下的數據
X = gmm_data';
%設置分類數 k
k = 2;

[z,model,llh] = myEM(X,k);
figure
plot(llh);
xlabel('迭代次數')
ylabel('log-likelihood')
figure
gscatter(X(1,:),X(2,:),z)
%模型參數
model.E
%ans(:,:,1) =
%    4.9060   -2.0062
%   -2.0062    1.0112

%ans(:,:,2) =
 %   4.7809    1.8768
 %   1.8768    0.9149
 
model.mu
%ans =
%   -0.1044    0.0722
%    0.0429    0.0167
model.w
%ans =
%    0.5196    0.4804
function[label,model,llh]=myEM(X,k)
% 輸入:
%     X 是輸入樣本集
%     k 是待分類別數
% 返回
%      label: X對應的標籤
%      model: GMMS
%      llh::log之後的極大似然(log-likelihood)

%指標
tol = 1e-8;
maxIter = 1000;

%初始化
n = length(X);
label = ceil(k*rand(1,n));
Z = full(sparse(1:n,label,1,n,k,n));
llh = -inf(1,maxIter); 

%迭代優化
for iter = 2:maxIter
    [~,label(1,:)] = max(Z,[],2);
    model = M_Step(X,Z); %計算 model
    [Z, llh(iter)] = E_Step(X,model);
    if abs(llh(iter)-llh(iter-1)) < tol*abs(llh(iter)); break; end;
end
llh = llh(2:iter);





function model = M_Step(X,Z)
[d,n] = size(X);%d是X的維度 
k = size(Z,2);%k 是類別
nk = sum(Z,1);% 計算類別下的數目
w = nk/n;%權重
mu = bsxfun(@times, X*Z, 1./nk);%計算mu

E = zeros(d,d,k);%協方差矩陣
r = sqrt(Z);
for i = 1:k
    Xo = bsxfun(@minus,X,mu(:,i));
    Xo = bsxfun(@times,Xo,r(:,i)');
    E(:,:,i) = Xo*Xo'/nk(i)+eye(d)*(1e-6);
end

model.mu = mu;
model.E = E;
model.w = w;


function [Z, llh] = E_Step(X, model)
mu = model.mu;
E = model.E;
w = model.w;

n = size(X,2);
k = size(mu,2);
Z = zeros(n,k);
for i = 1:k
    Z(:,i) = loggausspdf(X,mu(:,i),E(:,:,i));
end
Z = bsxfun(@plus,Z,log(w));
T = logsumexp(Z,2);
llh = sum(T)/n;
Z = exp(bsxfun(@minus,Z,T));


function y = loggausspdf(X, mu, E)
d = size(X,1);
X = bsxfun(@minus,X,mu);
U= chol(E);
Q = U'\X;
q = dot(Q,Q,1); 
c = d*log(2*pi)+2*sum(log(diag(U))); 
y = -(c+q)/2;



圖像結果:


Spark源碼圖(大圖見附錄)


 Spark實驗

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{GaussianMixture}
import org.apache.spark.mllib.linalg.Vectors


object GaussianMixtureExample {

  def main(args: Array[String]) {

    val conf = new SparkConf().setAppName("GaussianMixtureExample").setMaster("local")
    val sc = new SparkContext(conf)


    // Load and parse the data
    val data = sc.textFile("C:\\Users\\alienware\\IdeaProjects\\sparkCore\\data\\mllib\\gmm_data.txt")
    val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble))).cache()

    // Cluster the data into two classes using GaussianMixture
    val gmm = new GaussianMixture().setK(2).run(parsedData)

    // Save and load model
    //gmm.save(sc, "target/org/apache/spark/GaussianMixtureExample/GaussianMixtureModel")
    //val sameModel = GaussianMixtureModel.load(sc,"target/org/apache/spark/GaussianMixtureExample/GaussianMixtureModel")

    // output parameters of max-likelihood model
    for (i <- 0 until gmm.k) {
      println("weight=%f\nmu=%s\nsigma=\n%s\n" format
        (gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma))
    }

    /**
      * weight=0.481027
        mu=[0.07217189762937483,0.0166693219789788]
        sigma=
        4.776177833343064  1.874381267946877
        1.874381267946877  0.9140182655978455

        weight=0.518973
        mu=[-0.10458625214505539,0.042897423244107544]
        sigma=
        4.910485828947743   -2.008602407570325
        -2.008602407570325  1.0121329041756117
      */


    sc.stop()
  }
}

   


參考文獻

http://cs229.stanford.edu/notes/cs229-notes8.pdf


發佈了70 篇原創文章 · 獲贊 211 · 訪問量 31萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章