簡介
EM算法是十大數據挖掘算法,我原來看過不少講解EM的blog,但是都覺得是晦澀的講解和分析,今天突然整理到了,發現自己醍醐灌頂,身上不知道什麼地方被打通了,哈哈,總之總結一套我認爲很好理解的EM分析。後面還有python實例。
EM算法
引入
EM算法是一種優化算法,用途就是根據觀測到的樣本得到模型的參數。
這個場景似乎大家很熟悉,用樣本來估計模型參數,那不就是極大似然估計嘛。但是當模型中存在我們不能觀測到的未知變量Z的時候,很明顯,極大似然估計就不那麼好用了。
那麼有人說可以用梯度下降嗎?
哇哦,其實這麼想沒錯,不就是多了幾個隱變量嗎,真正明白梯度下降優化方法的朋友應該知道梯度下降是可以應對隱變量的優化問題的。
但是:
計算的複雜度隨着隱變量的數目以指數形式上升,這樣計算非常麻煩。
似然方程
由於無法直接求解,所以首先我們對隱變量Z計算期望,最大化一觀測數據的對數邊際似然:
迭代式算法
- 首先給初始參數設置初始值
- E步:此時模型參數已知(theta已知),可以根據訓練數據推斷最優隱變量Z的數值
- M步:此時Z已知,對theta進行極大似然估計
- 計算似然方程,檢查是否達到收斂標準,if not 進行第二步 else 返回;
算法
首先設置參數
- 計算第i個樣本屬於第k類的後驗概率
- 更新
- 計算似然函數,是否達到停止標準,如果沒有進行1。否則返回。
EM實現GMM優化
簡介
這裏我們將提供一個EM算法解決混合高斯模型參數估計的demo。
當然這種方法一般被用來進行聚類分析。
不過這次我們假設有兩個高斯分佈,在不知道 均值、方差、模型佔比的情況下,我們僅使用觀測到的樣本來得到這三個模型參數的估計值。
分析:在一般情況下,一個樣本屬於哪一個高斯分佈,也就是類別,是我們需要最終求解的,但是這時候就存在很多位置的隱變量:均值,方差。所以一般的優化方法就不適用了。我們選擇EM算法。
代碼
import numpy as np
import matplotlib.pyplot as plt
def gaussian(x, mu, sigma):
return np.exp(-(x-mu)**2/(2*sigma**2))/(np.sqrt(2*np.pi)*sigma)
N_boys=77230#比重77.23%
N_girls=22770#比重22.77%
N=N_boys+N_girls#觀測集大小
K=2#高斯分佈模型的數量
np.random.seed(1)
#男生身高數據
mu1=1.74#均值
sig1=0.0865#標準差
BoyHeights=np.random.normal(mu1,sig1,N_boys)#返回隨機數
BoyHeights.shape=N_boys,1
#女生身高數據
mu2=1.63
sig2=0.0642
GirlHeights=np.random.normal(mu2,sig2,N_girls)#返回隨機數
GirlHeights.shape=N_girls,1
data=np.concatenate((BoyHeights,GirlHeights))
u = np.random.random((1,2))
sigma = np.random.random((1,2))
a=np.random.random()
b=1-a
pi = np.array([[a,b]])
epoch = 0#迭代次數
while(True):
probability1 = gaussian(data, u[0][0], np.sqrt(sigma[0][0]))
probability2 = gaussian(data, u[0][1], np.sqrt(sigma[0][1]))
g1 = probability1 * pi[0][0]
g2 = probability2 * pi[0][1]
gg = g1 + g2
sigma[0][0] = np.dot((g1 / gg).T, (data - u[0][0])**2) / np.sum(g1 / gg)
sigma[0][1] = np.dot((g2 / gg).T, (data - u[0][1])**2) / np.sum(g2 / gg)
u[0][0] = np.dot((g1 / gg).T, data) / np.sum(g1 / gg)
u[0][1] = np.dot((g2 / gg).T, data) / np.sum(g2 / gg)
pi[0][0] = np.sum(g1 / gg) / N
pi[0][1] = np.sum(g2 / gg) / N
if epoch % 500 == 0:
print("第", epoch, "次迭代:")
print("u:", u)
print("sigma:", np.sqrt(sigma))
if epoch == 3000:
break
epoch += 1
最終得到結果:
大家共勉~~