Training of GMM

   太多公式,只能發圖,全英版提高逼格~













以下是GMM訓練的matlab代碼。爲了加快收斂的速度,在訓練之前加入了K-means聚類算法,初步估計各聚類的means、weights、variances:


function TrainGMM()
    % Entry point: draw 10000 samples from a 3-component Gaussian mixture,
    % seed the estimator with K-means, then refine the parameters with EM.
    clear;
    samples = getGMMSample(10000);
    % K-means gives a cheap initial estimate of weights/means/stddevs,
    % which reduces the number of EM iterations needed to converge.
    [initWeights, initMeans, initConvs] = K_Means(samples, 3);
    [weights, means, conv] = estimateGMM(samples, 3, initWeights, initMeans, initConvs);

function Sample=getGMMSample(SampleNum)
    % Generate SampleNum draws from a fixed 3-component Gaussian mixture
    % over [0,255] using inverse-transform (inverse-CDF) sampling.
    %
    % Mixture: 0.2*N(30,10^2) + 0.5*N(100,25^2) + 0.3*N(190,20^2)
    w     = [0.2, 0.5, 0.3];   % component weights
    mu    = [30, 100, 190];    % component means
    sigma = [10, 25, 20];      % component standard deviations
    grid = 0:1:255;
    % Tabulate the mixture CDF on the grid.
    cdfVals = w(1)*normcdf(grid, mu(1), sigma(1)) ...
            + w(2)*normcdf(grid, mu(2), sigma(2)) ...
            + w(3)*normcdf(grid, mu(3), sigma(3));
    % Uniform draws kept strictly inside (0.002, 0.998) so the
    % interpolation below never falls outside the tabulated CDF range.
    u = 0.002 + 0.996*rand(1, SampleNum);
    % Invert the CDF by linear interpolation and round to integer levels.
    Sample = round(interp1(cdfVals, grid, u, 'linear'));
    
    
function [weights,means,conv]=estimateGMM(Sample,numOfGaussiant,weights0,means0,conv0)
    % EM estimation of a 1-D Gaussian mixture model.
    %
    % Inputs:
    %   Sample         - 1 x N row vector of observations
    %   numOfGaussiant - number of mixture components
    %   weights0/means0/conv0 - initial weights, means, standard deviations
    % Outputs:
    %   weights/means/conv - fitted weights, means, standard deviations
    %
    % Stops when the change in total log-likelihood falls below quitThreshold.
    numOfSample=size(Sample,2);
    weights=weights0;
    means=means0;
    conv=conv0;
    quitThreshold=1e-5;  % (was misspelled "qiutThread")
    % Log-likelihood under the current model; carried over between
    % iterations so the mixture density is evaluated only once per pass
    % (the original recomputed the old model's likelihood every loop).
    logLikelihoodOld=sum(log(ModelOutputProbility(Sample,numOfGaussiant,weights,means,conv)));
    while(1)
        % --- E-step ---------------------------------------------------
        % weights(i)*N(xj | component i) for every sample j.
        GaussianOutput=zeros(numOfGaussiant,numOfSample);
        for i=1:1:numOfGaussiant
            GaussianOutput(i,:)=weights(i)*normpdf(Sample,means(i),conv(i));
        end
        % P(xj | current model): mixture density at each sample.
        mixtureDensity=sum(GaussianOutput,1);
        % Responsibilities: P(component i | xj, current model).
        responsibilities=zeros(numOfGaussiant,numOfSample);
        for i=1:1:numOfGaussiant
            responsibilities(i,:)=GaussianOutput(i,:)./mixtureDensity;
        end
        % --- M-step ---------------------------------------------------
        % Effective number of samples assigned to each component.
        Nq=sum(responsibilities,2)';
        % Means: responsibility-weighted average of the samples.
        for j=1:1:numOfGaussiant
            means(j)=(responsibilities(j,:)*Sample')/Nq(j);
        end
        % Standard deviations (sqrt turns the variance into a stddev,
        % matching what normpdf expects).
        for j=1:1:numOfGaussiant
            conv(j)=sqrt((responsibilities(j,:)*((Sample-means(j)).^2)')/Nq(j));
        end
        % Weights: fraction of the data explained by each component.
        weights=Nq./numOfSample;
        % Visualise the current fit, then pause so the figure is visible
        % (the original paused BEFORE drawing, delaying the update).
        drawDistribution(Sample,weights,means,conv);
        pause(1);
        % --- Convergence check ---------------------------------------
        logLikelihoodNew=sum(log(ModelOutputProbility(Sample,numOfGaussiant,weights,means,conv)));
        if abs(logLikelihoodNew-logLikelihoodOld)<quitThreshold
            break;
        end
        logLikelihoodOld=logLikelihoodNew;
    end
    
function [weights,means,convs]=K_Means(Sample,numOfCluser)
    % 1-D K-means clustering used to initialise the GMM parameters.
    %
    % Inputs:
    %   Sample      - 1 x N row vector of observations (values in [0,255])
    %   numOfCluser - number of clusters
    % Outputs:
    %   weights - 1 x numOfCluser fraction of samples per cluster
    %   means   - 1 x numOfCluser cluster centres
    %   convs   - 1 x numOfCluser per-cluster standard deviations
    sampleNum=size(Sample,2);
    % Random initial centres in [0,255].
    % Fix: the original hard-coded rand(1,3), ignoring numOfCluser.
    means=round(255*rand(1,numOfCluser));
    means0=means;
    threadHold=1e-5;  % stop when no centre moves more than this
    while(1)
        % Assignment step: distance of each sample to each centre
        % (1-D Euclidean distance = absolute difference).
        absDistanceSampleMeans=zeros(numOfCluser,sampleNum);
        for i=1:1:numOfCluser
            absDistanceSampleMeans(i,:)=abs(Sample-means(i));
        end
        % One-hot membership matrix: row = cluster, column = sample.
        cluserMartrix=zeros(numOfCluser,sampleNum);
        for i=1:1:sampleNum
            [minValue,minimumLocation]=min(absDistanceSampleMeans(:,i));
            cluserMartrix(minimumLocation,i)=1;
        end
        pointAsignToCluser=sum(cluserMartrix,2);  % samples per cluster
        % Update step: new centre = mean of its assigned samples.
        for i=1:1:numOfCluser
            if pointAsignToCluser(i)>0
                means(i)=(Sample*cluserMartrix(i,:)')/pointAsignToCluser(i);
            else
                % Fix: an empty cluster previously produced NaN (0/0);
                % reseed it at a random position instead.
                means(i)=round(255*rand());
            end
        end
        weights=pointAsignToCluser'./sampleNum;
        % Converged only when EVERY centre moved less than the threshold.
        % Fix: the original used min(), which could exit as soon as a
        % single centre had settled while others were still moving.
        if max(abs(means-means0))<threadHold
            % Final assignment with the converged centres, then a
            % per-cluster standard deviation for GMM initialisation.
            for i=1:1:numOfCluser
                absDistanceSampleMeans(i,:)=abs(Sample-means(i));
            end
            cluserMartrix=zeros(numOfCluser,sampleNum);
            for i=1:1:sampleNum
                [minValue,minimumLocation]=min(absDistanceSampleMeans(:,i));
                cluserMartrix(minimumLocation,i)=1;
            end
            convs=zeros(1,numOfCluser);
            for i=1:1:numOfCluser
                sumSqDev=((Sample-means(i)).^2)*cluserMartrix(i,:)';
                convs(1,i)=sqrt(sumSqDev/sum(cluserMartrix(i,:)));  % stddev
            end
            % Plot the K-means-initialised mixture density.
            % Fix: the original hard-coded 3 components here as well.
            X=0:1:255;
            Y=zeros(size(X));
            for i=1:1:numOfCluser
                Y=Y+weights(i)*normpdf(X,means(i),convs(i));
            end
            subplot(3,1,1);
            plot(X,Y);
            break;
        else
            means0=means;
        end
    end

 
function ProbilityVector=ModelOutputProbility(Sample,numOfGaussiant,weights,means,conv)
    % Mixture density P(xj | weights, means, conv) evaluated at every
    % sample: the weighted sum of the component Gaussian densities.
    ProbilityVector=zeros(1,size(Sample,2));
    for k=1:1:numOfGaussiant
        % Accumulate weights(k) * N(xj | means(k), conv(k)).
        ProbilityVector=ProbilityVector+weights(k)*normpdf(Sample,means(k),conv(k));
    end
    
    
function  drawDistribution(Sample,weights,means,conv)
    % Visualise the sample histogram (middle panel) and the currently
    % fitted 3-component mixture density (bottom panel).
    bins=1:1:255;
    counts=hist(Sample,bins);
    subplot(3,1,2)
    bar(bins,counts);
    % Fitted mixture density on the full intensity range.
    X=[0:1:255];
    density=weights(1)*normpdf(X,means(1),conv(1)) ...
          + weights(2)*normpdf(X,means(2),conv(2)) ...
          + weights(3)*normpdf(X,means(3),conv(3));
    subplot(3,1,3);
    plot(X,density);
  

算法運行效果截圖:


發佈了23 篇原創文章 · 獲贊 6 · 訪問量 4萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章