简介
EM算法是十大数据挖掘算法,我原来看过不少讲解EM的blog,但是都觉得是晦涩的讲解和分析,今天突然整理到了,发现自己醍醐灌顶,身上不知道什么地方被打通了,哈哈,总之总结一套我认为很好理解的EM分析。后面还有python实例。
EM算法
引入
EM算法是一种优化算法,用途就是根据观测到的样本得到模型的参数。
这个场景似乎大家很熟悉,用样本来估计模型参数,那不就是极大似然估计嘛。但是当模型中存在我们不能观测到的未知变量Z的时候,很明显,极大似然估计就不那么好用了。
那么有人说可以用梯度下降吗?
哇哦,其实这么想没错,不就是多了几个隐变量吗,真正明白梯度下降优化方法的朋友应该知道梯度下降是可以应对隐变量的优化问题的。
但是:
计算的复杂度随着隐变量的数目以指数形式上升,这样计算非常麻烦。
似然方程
由于无法直接求解,所以首先我们对隐变量Z计算期望,最大化一观测数据的对数边际似然:
迭代式算法
- 首先给初始参数设置初始值
- E步:此时模型参数已知(theta已知),可以根据训练数据推断最优隐变量Z的数值
- M步:此时Z已知,对theta进行极大似然估计
- 计算似然方程,检查是否达到收敛标准,if not 进行第二步 else 返回;
算法
首先设置参数
- 计算第i个样本属于第k类的后验概率
- 更新
- 计算似然函数,是否达到停止标准,如果没有进行1。否则返回。
EM实现GMM优化
简介
这里我们将提供一个EM算法解决混合高斯模型参数估计的demo。
当然这种方法一般被用来进行聚类分析。
不过这次我们假设有两个高斯分布,在不知道 均值、方差、模型占比的情况下,我们仅使用观测到的样本来得到这三个模型参数的估计值。
分析:在一般情况下,一个样本属于哪一个高斯分布,也就是类别,是我们需要最终求解的,但是这时候就存在很多位置的隐变量:均值,方差。所以一般的优化方法就不适用了。我们选择EM算法。
代码
import numpy as np
import matplotlib.pyplot as plt
def gaussian(x, mu, sigma):
return np.exp(-(x-mu)**2/(2*sigma**2))/(np.sqrt(2*np.pi)*sigma)
N_boys=77230#比重77.23%
N_girls=22770#比重22.77%
N=N_boys+N_girls#观测集大小
K=2#高斯分布模型的数量
np.random.seed(1)
#男生身高数据
mu1=1.74#均值
sig1=0.0865#标准差
BoyHeights=np.random.normal(mu1,sig1,N_boys)#返回随机数
BoyHeights.shape=N_boys,1
#女生身高数据
mu2=1.63
sig2=0.0642
GirlHeights=np.random.normal(mu2,sig2,N_girls)#返回随机数
GirlHeights.shape=N_girls,1
data=np.concatenate((BoyHeights,GirlHeights))
u = np.random.random((1,2))
sigma = np.random.random((1,2))
a=np.random.random()
b=1-a
pi = np.array([[a,b]])
epoch = 0#迭代次数
while(True):
probability1 = gaussian(data, u[0][0], np.sqrt(sigma[0][0]))
probability2 = gaussian(data, u[0][1], np.sqrt(sigma[0][1]))
g1 = probability1 * pi[0][0]
g2 = probability2 * pi[0][1]
gg = g1 + g2
sigma[0][0] = np.dot((g1 / gg).T, (data - u[0][0])**2) / np.sum(g1 / gg)
sigma[0][1] = np.dot((g2 / gg).T, (data - u[0][1])**2) / np.sum(g2 / gg)
u[0][0] = np.dot((g1 / gg).T, data) / np.sum(g1 / gg)
u[0][1] = np.dot((g2 / gg).T, data) / np.sum(g2 / gg)
pi[0][0] = np.sum(g1 / gg) / N
pi[0][1] = np.sum(g2 / gg) / N
if epoch % 500 == 0:
print("第", epoch, "次迭代:")
print("u:", u)
print("sigma:", np.sqrt(sigma))
if epoch == 3000:
break
epoch += 1
最终得到结果:
大家共勉~~