中山大學人工智能實驗12 EM Algorithm (C++/Python)

描述

（此處原為圖片，內容缺失）

算法

（此處原為圖片，內容缺失）

任務

（此處原為圖片，內容缺失）

代碼

# Read Football.txt. Each line holds a country name followed by tab-separated
# numbers; per the original notes these are 7 values giving the country's
# rankings in past World Cup and Asian Cup tournaments.
dataSet = []
Country = []
with open("Football.txt") as fr:
    for line in fr:
        fields = line.strip().split('\t')
        if not fields[0]:
            # Skip blank lines (e.g. a trailing newline at end of file); they
            # would otherwise add an empty country and an empty feature row.
            continue
        Country.append(fields[0])
        dataSet.append([float(value) for value in fields[1:]])
print(Country)
# Bare expression: only displays the parsed rows when run in a notebook cell.
dataSet

（此處原為程序輸出截圖，內容缺失）

# -*- coding: utf-8 -*-
#  使用EM算法求解GMM（高斯混合模型），EM算法採用scikit-learn包提供的API
#  數據集：《機器學習》西瓜數據集4.0（文件watermelon4.txt，可用loadData讀取）；下方main實際使用前面讀入的Football數據

from sklearn import mixture
import matplotlib.pyplot as plt
import numpy as np


# 預處理數據
def loadData(filename):
    """Load a whitespace-separated numeric text file as a list of float rows.

    Args:
        filename: path of the text file; every non-blank line is one sample.

    Returns:
        A list with one list of floats per non-blank line, in file order.

    Raises:
        ValueError: if a field cannot be parsed as a float.
    """
    dataSet = []
    # `with` guarantees the file is closed even if a float() conversion fails.
    with open(filename) as fr:
        for line in fr:
            # split() with no argument tolerates repeated spaces/tabs and the
            # trailing newline; split(' ') would produce '' fields there.
            fields = line.split()
            if fields:  # ignore blank lines
                dataSet.append([float(value) for value in fields])
    return dataSet


def test_GMM(dataMat, components=3, iter=100, cov_type="full"):
    """Fit a Gaussian mixture model with EM and assign every sample a component.

    Uses scikit-learn's ``mixture.GaussianMixture`` and prints the fitted means,
    covariance matrices and mixture weights.

    Args:
        dataMat: samples, one row per observation.
        components: number of Gaussian components to fit.
        iter: maximum number of EM iterations (name kept for backward
            compatibility with keyword callers, although it shadows the builtin).
        cov_type: covariance type passed to GaussianMixture
            ('spherical', 'tied', 'diag' or 'full').

    Returns:
        A tuple ``(means, labels)``: the fitted component means
        (``clst.means_``) and the predicted component index for each row.
    """
    # Use this function's own `components` argument, not a module-level
    # `n_components` that only exists when the file runs as __main__.
    clst = mixture.GaussianMixture(n_components=components, max_iter=iter,
                                   covariance_type=cov_type)
    clst.fit(dataMat)
    predicted_labels = clst.predict(dataMat)
    # Single-argument print(...) produces the same output on Python 2 and 3.
    print("Means:")
    print(clst.means_)
    print("Covariances Matrix:")
    print(clst.covariances_)
    print("Weights of each Gussian:")
    print(clst.weights_)
    return clst.means_, predicted_labels



def showCluster(dataMat, k, centroids, clusterAssment):
    """Plot 2-D samples coloured by cluster label and overlay the k centroids.

    Args:
        dataMat: 2-D indexable array of samples, shape (n_samples, 2);
            elements are read as ``dataMat[row, col]``.
        k: number of clusters / centroids to draw.
        centroids: array indexed as ``centroids[c, 0]`` / ``centroids[c, 1]``.
        clusterAssment: per-sample cluster labels, convertible with ``int()``.

    Returns:
        1 after printing a message if the data is not 2-dimensional or k
        exceeds the number of available marker styles; otherwise None, after
        showing the figure.
    """
    sample_styles = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr', '<r', 'pr']
    centroid_styles = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db', '<b', 'pb']

    n_rows, n_cols = dataMat.shape
    if n_cols != 2:
        print("Sorry! I can not draw because the dimension of your data is not 2!")
        return 1
    if k > len(sample_styles):
        print("Sorry! Your k is too large!")
        return 1

    # One marker per sample, styled by its assigned cluster.
    for row in range(n_rows):
        style = sample_styles[int(clusterAssment[row])]
        plt.plot(dataMat[row, 0], dataMat[row, 1], style)

    # Larger markers for the cluster centres.
    for c in range(k):
        plt.plot(centroids[c, 0], centroids[c, 1], centroid_styles[c], markersize=12)

    plt.show()


if __name__ == "__main__":
    # NOTE(review): `dataSet` and `Country` are not defined in this cell; they
    # are expected to be the globals built by the Football.txt loading code
    # shown earlier (notebook-style session) — confirm when run standalone.
    n_components = 3
    n_iter = 100
    cov_types = ['spherical', 'tied', 'diag', 'full']
    centroids, labels = test_GMM(dataSet, n_components, n_iter, cov_types[3])
    # The data has more than two dimensions but the plot is 2-D, so only the
    # first two features are drawn. np.asmatrix replaces np.mat, which was
    # removed in NumPy 2.0.
    showCluster(np.asmatrix(dataSet)[:, 0:2], n_components, centroids, labels)

    # Order clusters by the sum of their mean vectors. The features are
    # tournament rankings, so a smaller sum presumably means stronger results.
    # Kept inside the main guard: it needs `centroids`/`labels`, which only
    # exist after the fit above.
    rank = sorted(range(n_components), key=lambda c: sum(centroids[c]))

    for cnt, cluster in enumerate(rank, start=1):
        print("Class %d :" % cnt)
        for country, label in zip(Country, labels):
            if label == cluster:
                print(country)

（此處原為程序輸出截圖，內容缺失）

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章