An EM Algorithm Implementation for the Three-Coin Problem

Code:

def three_coins(pi, p, q, y, steps):
    """Run EM on the 0/1 observations y, printing the estimates after each iteration."""
    e = 0.00001  # convergence tolerance
    for _ in range(steps):
        # E-step: responsibility of the first component for each observation
        miu = []
        for yi in y:
            a = pi * (p ** yi) * ((1 - p) ** (1 - yi))
            b = (1 - pi) * (q ** yi) * ((1 - q) ** (1 - yi))
            miu.append(a / (a + b))
        # M-step: re-estimate the parameters from the responsibilities
        temp = [1 - m for m in miu]
        new_pi = sum(miu) / len(miu)
        new_p = sum(m * yi for m, yi in zip(miu, y)) / sum(miu)
        new_q = sum(t * yi for t, yi in zip(temp, y)) / sum(temp)
        print(new_pi, new_p, new_q)
        # Stop once every parameter changes by less than e
        if abs(pi - new_pi) < e and abs(p - new_p) < e and abs(q - new_q) < e:
            print("Done")
            break
        pi, p, q = new_pi, new_p, new_q
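
For reference, the updates computed inside the loop can be written out as formulas. This is a sketch that simply transcribes the code: mu_i stands for one entry of miu, n for len(y), and the primed symbols for new_pi, new_p and new_q. The notation is an assumption made here for illustration and does not appear in the original post.

```latex
% E-step: responsibility of the first component for observation y_i
\mu_i = \frac{\pi\, p^{y_i} (1-p)^{1-y_i}}
             {\pi\, p^{y_i} (1-p)^{1-y_i} + (1-\pi)\, q^{y_i} (1-q)^{1-y_i}}

% M-step: re-estimate the parameters from the responsibilities
\pi' = \frac{1}{n} \sum_{i=1}^{n} \mu_i, \qquad
p'   = \frac{\sum_{i=1}^{n} \mu_i\, y_i}{\sum_{i=1}^{n} \mu_i}, \qquad
q'   = \frac{\sum_{i=1}^{n} (1-\mu_i)\, y_i}{\sum_{i=1}^{n} (1-\mu_i)}
```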

Run:

# sample input
input_y = [1,1,0,1,0,0,1,0,1,1]
# initial parameters
init_pi, init_p, init_q, iter_num = 0.5, 0.5, 0.5, 100
# run
three_coins(init_pi, init_p, init_q, input_y, iter_num)

Output:

0.5 0.6 0.6
0.5 0.6 0.6
Done

With different initial values, the EM algorithm can converge to different values.

# initial parameters
init_pi, init_p, init_q, iter_num = 0.4, 0.6, 0.7, 100
# run
three_coins(init_pi, init_p, init_q, input_y, iter_num)

Output:

0.40641711229946526 0.5368421052631579 0.6432432432432431
0.40641711229946537 0.5368421052631579 0.6432432432432431
Done
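
One way to see why different starting points end at different estimates is to compare how well each result explains the data. The sketch below assumes the model that the code implements: each observation is 1 with probability pi*p + (1-pi)*q. The helper names marginal_heads and log_likelihood, and the labels fit_a and fit_b, are hypothetical and introduced only for this illustration.

```python
import math


def marginal_heads(pi, p, q):
    """P(y = 1) under the mixture: head probability p with weight pi, q otherwise."""
    return pi * p + (1 - pi) * q


def log_likelihood(pi, p, q, y):
    """Log-likelihood of the observed 0/1 sequence y under the mixture."""
    r = marginal_heads(pi, p, q)
    return sum(math.log(r) if yi == 1 else math.log(1 - r) for yi in y)


input_y = [1, 1, 0, 1, 0, 0, 1, 0, 1, 1]

# The two converged parameter sets reported in the outputs above.
fit_a = (0.5, 0.6, 0.6)
fit_b = (0.40641711229946537, 0.5368421052631579, 0.6432432432432431)

for name, params in (("fit_a", fit_a), ("fit_b", fit_b)):
    print(name, marginal_heads(*params), log_likelihood(*params, input_y))
```

Both parameter sets give a marginal probability of about 0.6, which is the fraction of 1s in input_y, so their log-likelihoods agree. The observations constrain only the combination pi*p + (1-pi)*q, which is why the point EM stops at depends on the initial values.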