EM算法初始值設定的影響 (三硬幣模型)

先自定義pi,p,q的值,隨機生成實驗結果,

真實 pi:0.3500.p:0.6600.q:0.3300
產生的實驗數據 pi:0.3522.p:0.6729.q:0.3341

將實驗結果帶入EM算法估計pi,p,q的值

 

第一次theta的初始值設置的裏真實值比較遠,

初值設定爲 pi:0.5000.p:0.5000.q:0.5000

運行結果爲 pi:0.5000.p:0.4534.q:0.4534

 

第二次theta的初始值設置的裏真實值比較近,

初值設定爲 pi:0.3000.p:0.6000.q:0.3000

運行結果爲 pi:0.3168.p:0.6606.q:0.3573

 

可以看出EM算法非常依賴初始值的正確設定

 

代碼:

import numpy as np

def prepareData(pi, p, q):
    sampleNum = 10000
    z_array = np.random.binomial(1, pi, sampleNum)
    p_array = np.random.binomial(1, p, np.sum(z_array))
    q_array = np.random.binomial(1, q, sampleNum - sum(z_array))
    p_count = 0
    q_count = 0
    result_array = np.ndarray(sampleNum)
    for i in range(sampleNum):
        if (z_array[i] == 1):
            result_array[i] = p_array[p_count]
            p_count += 1
        else:
            result_array[i] = q_array[q_count]
            q_count += 1

    print('真實 pi:{:.4f}.p:{:.4f}.q:{:.4f}'.format(pi, p, q))
    print('數據 pi:{:.4f}.p:{:.4f}.q:{:.4f}'.format(len(p_array) / sampleNum, np.sum(p_array) / len(p_array),
                                                  np.sum(q_array) / len(q_array)))

    return result_array


def em_update(theta, y_array):
    pi = theta[0]
    p = theta[1]
    q = theta[2]

    p_prob = pi * pow(p, y_array) * pow(1 - p, 1 - y_array)
    q_prob = (1 - pi) * pow(q, y_array) * pow(1 - q, 1 - y_array)
    mu_array = p_prob / (p_prob + q_prob)
    pi_new = np.sum(mu_array) / len(mu_array)
    p_new = np.sum(mu_array * y_array) / np.sum(mu_array)
    q_new = np.sum((1 - mu_array) * y_array) / np.sum((1 - mu_array))
    return (pi_new, p_new, q_new)


def em(theta, y_array, iterateNum):

    print('初值 pi:{:.4f}.p:{:.4f}.q:{:.4f}'.format(theta[0], theta[1], theta[2]))

    for i in range(iterateNum):
        theta_new = em_update(theta, y_array)
        if (theta_new == theta):
            break
        else:
            theta = theta_new

    print('迭代次數:',i+1)
    print('估計 pi:{:.4f}.p:{:.4f}.q:{:.4f}'.format(theta_new[0], theta_new[1], theta_new[2]))
    return theta_new


def run():
    # 自定義真實數據
    pi=0.35
    p=0.66
    q=0.33

    # 建立實驗結果
    y_array=prepareData(pi,p,q)
    print('******************************')
    # em算法推算
    print("初值與真實值差異較大時")
    theta_init = (0.5, 0.5, 0.5)
    theta_predict = em(theta_init, y_array, 10000)
    print('差異數 pi:{:.2f}.p:{:.2f}.q:{:.2f}'.format(abs(theta_predict[0]-pi),abs(theta_predict[1]-p),abs(theta_predict[2]-q)))
    print('******************************')
    print("初值與真實值差異較小時")
    theta_init = (0.3, 0.6, 0.3)
    theta_predict = em(theta_init, y_array, 10000)
    print('差異數 pi:{:.2f}.p:{:.2f}.q:{:.2f}'.format(abs(theta_predict[0]-pi),abs(theta_predict[1]-p),abs(theta_predict[2]-q)))



if __name__ == '__main__':
    import sys

    run()
    sys.exit(0)

 

運行結果

真實 pi:0.3500.p:0.6600.q:0.3300
數據 pi:0.3522.p:0.6729.q:0.3341
******************************
初值與真實值差異較大時
初值 pi:0.5000.p:0.5000.q:0.5000
迭代次數: 2
估計 pi:0.5000.p:0.4534.q:0.4534
差異數 pi:0.15.p:0.21.q:0.12
******************************
初值與真實值差異較小時
初值 pi:0.3000.p:0.6000.q:0.3000
迭代次數: 4
估計 pi:0.3168.p:0.6606.q:0.3573
差異數 pi:0.03.p:0.00.q:0.03

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章