高斯混合模型的C++實現

高斯混合模型的C++實現

原理

GMM將數據的分佈通過多個高斯模型進行擬合。GMM是一種聚類算法,每個component就是一個聚類中心。高斯混合模型可以得到每個數據屬於每個模型的概率,是一種軟聚類算法。這是來自《統計學習方法》中的定義:

這裏寫圖片描述

過程

高斯混合模型使用EM算法估計模型參數。
1. 初始化模型的個數和每個高斯模型的參數,設定迭代結束條件(迭代次數,誤差閾值)
2. 迭代:對於每一個數據,計算在每一個高斯模型中的概率
3. 根據計算得到的概率更新每個模型的參數(均值,方差)
4. 當超過迭代次數或者更新小於閾值時結束迭代。

代碼實現

代碼參考自網上大神,具體出處忘記了,加以修改和添加註釋,如有侵權請聯繫~

頭文件

#ifndef _GMM_H
#define _GMM_H
#include <vector>
#include <cmath>
using namespace std;
class GMM
{
public:
    void Init(const vector<double> &inputData, const int clustNum = 5, double eps = 0.01, double max_steps = 20);
    void train();
    int predicate(double x);//預測輸入的數據屬於哪一類
    void print();
protected:
    int clusterNum;             // 限制
    vector<double> means;
    vector<double> means_bkp;       // 上一次的迭代數據
    vector<double> sigmas;
    vector<double> sigmas_bkp;      // 上一次的迭代數據
    vector<double> probilities;
    vector<double> probilities_bkp;
    vector<vector<double>> memberships;     // 存儲屬於哪一個類別
    vector<vector<double>> memberships_bkp;
    vector<double> data;
    int dataNum;        // 數據數量
    double epslon;      // 相差的閾值
    double max_steps;   // 迭代次數
private:
    double gauss(const double x, const double m, const double sigma);
};
#endif

實現

#include "GMM.h"
#include <iostream>
#include <fstream>
#include <stdlib.h>
#include <Windows.h>
using namespace std;

void GMM::Init(const vector<double> &inputData, const int clustNum, double eps, double max_steps)
{
    /*獲取輸入數據*/
    this->data = inputData;
    this->dataNum = data.size();
    /*存儲最終需要的結果*/
    this->clusterNum = clustNum;        // 聚類數量
    this->epslon = eps;                 // 閾值
    this->max_steps = max_steps;        // 最大的迭代次數
    /*保留每一個類別的均值,方差參數,保留上一個的參數*/
    this->means.resize(clusterNum);
    this->means_bkp.resize(clusterNum);
    this->sigmas.resize(clusterNum);
    this->sigmas_bkp.resize(clusterNum);
    /*保留每一個數據對於每一個類別下的概率,由在這個類別下的概率除以到各個類別的總概率得到*/
    this->memberships.resize(clusterNum);
    this->memberships_bkp.resize(clusterNum);
    for (int i = 0; i < clusterNum; i++)
    {
        memberships[i].resize(data.size());
        memberships_bkp[i].resize(data.size());
    }
    /*每一個類別的可能性*/
    this->probilities.resize(clusterNum);
    this->probilities_bkp.resize(clusterNum);
    //initialize mixture probabilities 初始化每個類別的參數
    for (int i = 0; i < clusterNum; i++)
    {
        probilities[i] = probilities_bkp[i] = 1.0 / (double)clusterNum;
        //init means
        means[i] = means_bkp[i] = 255.0*i / (clusterNum);
        //init sigma
        sigmas[i] = sigmas_bkp[i] = 50;
    }
}

void GMM::train()
{
    //compute membership probabilities
    int i, j, k, m;
    double sum = 0, sum2;
    int steps = 0;
    bool go_on;
    do          // 迭代
    {
        for (k = 0; k < clusterNum; k++)
        {
            //compute membership probabilities
            for (j = 0; j < data.size(); j++)
            {
                //計算p(k|n),計算每一個數據在每一個類別的值的加權和
                sum = 0;
                for (m = 0; m < clusterNum; m++)
                {
                    sum += probilities[m] * gauss(data[j], means[m], sigmas[m]);
                }
                //求分子,第j個數據在第k類中的所佔的比例
                memberships[k][j] = probilities[k] * gauss(data[j], means[k], sigmas[k]) / sum;
            }
            //求均值
            //求條件概率的和,將每個數據在第k類中的概率進行累加
            sum = 0;
            for (i = 0; i < dataNum; i++)
            {
                sum += memberships[k][i];
            }
            //得到每個數據在屬於第k類的加權值之和
            sum2 = 0;
            for (j = 0; j < dataNum; j++)
            {
                sum2 += memberships[k][j] * data[j];
            }
            //得到新的均值 由概率加權和除以總概率作爲均值
            means[k] = sum2 / sum;
            //求方差  由到均值的平方的加權和作爲新的方差
            sum2 = 0;
            for (j = 0; j < dataNum; j++)
            {
                sum2 += memberships[k][j] * (data[j] - means[k])*(data[j] - means[k]);
            }
            sigmas[k] = sqrt(sum2 / sum);
            //求概率
            probilities[k] = sum / dataNum;
        }//end for k
        //check improvement
        go_on = false;
        for (k = 0; k<clusterNum; k++)
        {
            if (means[k] - means_bkp[k]>epslon)
            {
                go_on = true;
                break;
            }
        }
        //back up
        this->means_bkp = means;
        this->sigmas_bkp = sigmas;
        this->probilities_bkp = probilities;
    } while (go_on&&steps++ < max_steps); //end do while
}

double GMM::gauss(const double x, const double m, const double sigma)
{
    return 1.0 / (sqrt(2 * 3.1415926)*sigma)*exp(-0.5*(x - m)*(x - m) / (sigma*sigma));
}

// 預測
int GMM::predicate(double x)
{
    double max_p = -100;
    int i;
    double current_p;
    int bestIdx = 0;
    for (i = 0; i < clusterNum; i++)
    {
        current_p = gauss(x, means[i], sigmas[i]);
        if (current_p > max_p)
        {
            max_p = current_p;
            bestIdx = i;
        }
    }
    return bestIdx;
}

void GMM::print()
{
    int i;
    for (i = 0; i < clusterNum; i++)
    {
        cout << "Mean: " << means[i] << " Sigma: " << sigmas[i] << " Mixture Probability: " << probilities[i] << endl;
    }
}

github地址

如有錯誤,歡迎指出~

發佈了36 篇原創文章 · 獲贊 65 · 訪問量 15萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章