代碼:
def three_coins(pi, p, q, y, steps):
    """Fit a two-component Bernoulli mixture ("three coins") with EM.

    Each observation ``yi`` (0 or 1) is modelled as coming from the first
    component with weight ``pi`` and success probability ``p``, or from the
    second component with weight ``1 - pi`` and success probability ``q``.

    The estimates are printed after every iteration, and ``"Done"`` is
    printed once all three parameters change by less than the tolerance.

    Args:
        pi: Initial mixing weight of the first component.
        p: Initial success probability of the first component.
        q: Initial success probability of the second component.
        y: Sequence of 0/1 observations.
        steps: Maximum number of EM iterations.

    Returns:
        Tuple ``(pi, p, q)`` holding the most recent estimates. When
        ``steps`` is 0 the initial values are returned unchanged.

    Raises:
        ZeroDivisionError: If ``y`` is empty, or if the responsibilities of
            one component sum to zero (e.g. ``pi`` equal to 0 or 1).
    """
    tolerance = 0.00001
    for _ in range(steps):
        # E-step: responsibility of the first component for each observation.
        miu = []
        for yi in y:
            first = pi * (p ** yi) * ((1 - p) ** (1 - yi))
            second = (1 - pi) * (q ** yi) * ((1 - q) ** (1 - yi))
            miu.append(first / (first + second))

        # M-step: re-estimate the parameters from the responsibilities.
        complement = [1 - m for m in miu]
        new_pi = sum(miu) / len(miu)
        new_p = sum(m * yi for m, yi in zip(miu, y)) / sum(miu)
        new_q = sum(c * yi for c, yi in zip(complement, y)) / sum(complement)
        print(new_pi, new_p, new_q)

        # Check convergence against the previous estimates before
        # overwriting them, so the returned values are always the latest.
        converged = (
            abs(pi - new_pi) < tolerance
            and abs(p - new_p) < tolerance
            and abs(q - new_q) < tolerance
        )
        pi, p, q = new_pi, new_p, new_q
        if converged:
            print("Done")
            break
    return pi, p, q
運行:
# sample input
input_y = [1,1,0,1,0,0,1,0,1,1]
# initial parameters
init_pi, init_p, init_q, iter_num = 0.5, 0.5, 0.5, 100
# run (the observations must be passed as the fourth argument, before the
# iteration count; omitting them raises TypeError)
three_coins(init_pi, init_p, init_q, input_y, iter_num)
輸出:
0.5 0.6 0.6
0.5 0.6 0.6
Done
更換不同的初始值後，EM 算法收斂到的結果也會有所不同（EM 算法對初始值敏感）。
# initial parameters
init_pi, init_p, init_q, iter_num = 0.4, 0.6, 0.7, 100
# run (reuses the sample observations input_y defined for the first run;
# they must be passed as the fourth argument, before the iteration count)
three_coins(init_pi, init_p, init_q, input_y, iter_num)
輸出:
0.40641711229946526 0.5368421052631579 0.6432432432432431
0.40641711229946537 0.5368421052631579 0.6432432432432431
Done