1. 首先,用MATLAB生成符合雙峯正態分佈的隨機數。
% Draw n samples from a two-component Gaussian mixture and plot their histogram.
% With probability r a sample comes from N(mu1, sigma1^2); otherwise it comes
% from N(mu2, sigma2^2).
r=0.5;        % mixing weight of the first component
mu1=50;
sigma1=20;
mu2=200;
sigma2=20;
n=10000;      % number of samples
% A logical mask selects the component for each sample. This replaces
% heaviside(), which belongs to the Symbolic Math Toolbox and evaluates to
% 0.5 at zero (which would blend the two components on a tie).
fromSecond=rand(n,1)>r;
x=mu1+sigma1*randn(n,1);
x(fromSecond)=mu2+sigma2*randn(nnz(fromSecond),1);
hist(x)
2. 然後就是廢話少說,上代碼。(注意:Java 程序會從 data/data.txt 讀取數據,每行一個數值,所以要先把上一步生成的 x 按這個格式寫入該文件。)
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
/**
 * Fits a two-component, one-dimensional Gaussian mixture to the samples stored in
 * {@code data/data.txt} (one number per line) with the EM algorithm and prints the
 * parameter estimates after every iteration.
 *
 * <p>Reference: http://blog.163.com/huai_jing@126/blog/static/17186198320119231094873/
 */
public class MatlabData {
    /** Samples, stored 1-based: the valid entries are {@code d[1..N]}. */
    double[] d = new double[10005];
    /** Responsibility of the second component for each sample, 1-based like {@code d}. */
    double[] r = new double[10005];
    /** Number of samples held in {@code d}; {@code main} sets it to the number of values read. */
    int N = 10000;

    /**
     * Evaluates the normal probability density function.
     *
     * @param mu    mean of the distribution
     * @param sigma standard deviation of the distribution
     * @param x     point at which the density is evaluated
     * @return the density of N(mu, sigma^2) at {@code x}
     */
    private double normal(double mu, double sigma, double x) {
        double diff = x - mu;
        double exponent = -diff * diff / (2 * sigma * sigma);
        return Math.exp(exponent) / (sigma * Math.sqrt(2 * Math.PI));
    }

    /**
     * Runs EM on {@code d[1..N]} until every parameter changes by less than 1e-6 between two
     * consecutive iterations, printing the estimates after each iteration.
     *
     * @param pai initial mixing weight of the second component
     * @param mu1 initial mean of the first component
     * @param mu2 initial mean of the second component
     * @param s1  initial standard deviation of the first component
     * @param s2  initial standard deviation of the second component
     */
    private void EM(double pai, double mu1, double mu2, double s1, double s2) {
        // The responsibility array has to cover every 1-based sample index.
        if (r.length <= N) {
            r = new double[N + 1];
        }
        final double eps = 1e-6;
        // Previous-iteration values; each one is overwritten before the convergence check reads it.
        double opai = Double.NaN;
        double omu1 = Double.NaN;
        double omu2 = Double.NaN;
        double os1 = Double.NaN;
        double os2 = Double.NaN;
        int cnt = 0;
        do {
            prt("====================== " + (++cnt) + "\n");

            // E-step: posterior probability that each sample belongs to the second component.
            for (int i = 1; i <= N; i++) {
                double first = (1 - pai) * normal(mu1, s1, d[i]);
                double second = pai * normal(mu2, s2, d[i]);
                r[i] = second / (first + second);
            }

            // M-step, means: responsibility-weighted averages of the samples.
            double weight1 = 0;
            double weight2 = 0;
            double sum1 = 0;
            double sum2 = 0;
            for (int i = 1; i <= N; i++) {
                double w2 = r[i];
                double w1 = 1 - w2;
                weight1 += w1;
                weight2 += w2;
                sum1 += w1 * d[i];
                sum2 += w2 * d[i];
            }
            omu1 = mu1;
            omu2 = mu2;
            mu1 = sum1 / weight1;
            mu2 = sum2 / weight2;

            // M-step, spreads: weighted squared deviations around the UPDATED means. Using the
            // previous means here would add (newMean - oldMean)^2 to every variance estimate.
            double sq1 = 0;
            double sq2 = 0;
            for (int i = 1; i <= N; i++) {
                double dev1 = d[i] - mu1;
                double dev2 = d[i] - mu2;
                sq1 += (1 - r[i]) * dev1 * dev1;
                sq2 += r[i] * dev2 * dev2;
            }
            os1 = s1;
            os2 = s2;
            s1 = Math.sqrt(sq1 / weight1);
            s2 = Math.sqrt(sq2 / weight2);

            // M-step, mixing weight: average responsibility of the second component.
            opai = pai;
            pai = weight2 / N;

            prt("mu1 : " + mu1 + "\n");
            prt("mu2 : " + mu2 + "\n");
            prt("s1 : " + s1 + "\n");
            prt("s2 : " + s2 + "\n");
            prt("pai : " + pai + "\n");
        } while (Math.abs(mu1 - omu1) >= eps
                || Math.abs(mu2 - omu2) >= eps
                || Math.abs(s1 - os1) >= eps
                || Math.abs(s2 - os2) >= eps
                || Math.abs(pai - opai) >= eps);
    }

    /**
     * Reads the samples from {@code data/data.txt} (one number per line, blank lines ignored)
     * and fits the mixture to them.
     *
     * @param args unused
     * @throws IOException if the data file cannot be opened or read
     */
    public static void main(String[] args) throws IOException {
        MatlabData m = new MatlabData();
        String iptPath = "data/data.txt";
        int count = 0;
        // try-with-resources closes the reader even when a line fails to parse.
        try (BufferedReader br = new BufferedReader(new FileReader(iptPath))) {
            String line;
            while ((line = br.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline at the end of the file
                }
                if (count + 1 >= m.d.length) {
                    // Grow the 1-based buffer so that files with more lines than the initial capacity fit.
                    double[] grown = new double[m.d.length * 2];
                    System.arraycopy(m.d, 0, grown, 0, m.d.length);
                    m.d = grown;
                }
                m.d[++count] = Double.parseDouble(line);
            }
        }
        // Fit exactly the values that were read instead of a fixed 10000 entries.
        m.N = count;
        prt("end of read : " + count + "!\n");
        if (count == 0) {
            prt("no samples found in " + iptPath + ", nothing to fit\n");
            return;
        }
        m.EM(0.1, 125.2, 125.4, 6023.4, 6023.4);
        prt("end of the EM !\n");
    }

    /** Prints {@code string} to standard output without appending a line break. */
    private static void prt(String string) {
        System.out.print(string);
    }
}
結果如下:
end of read : 10000!
====================== 1
...
====================== 9
mu1 : 50.04014214593736
mu2 : 199.9913405589886
s1 : 20.07997207253577
s2 : 19.995272104999646
pai : 0.5016524895149193
end of the EM !
可以跟第一步的參數(mu1=50、mu2=200、sigma1=sigma2=20;pai 是第二個分量的權重,對應 1-r=0.5)對比一下,還是蠻準的嘛!
參考文獻:
EM算法(expectation-maximization algorithm)