機器學習(九):鳶尾花-邏輯迴歸

注:基於現有案例教程

鳶尾花數據來源於seaborn中自帶的數據集,很多類似的都會自帶這個數據集

代碼如下:

import pymc3 as pm
import pandas as pd
import scipy.stats as stats
import theano.tensor as tt
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
if __name__=='__main__':
    #載入數據集
    iris = sns.load_dataset("iris")
    #花萼長度這一特徵(自變量)來區分 setosa 和 versicolor 這兩個種類
    df = iris.query("species == ('setosa', 'versicolor')")
    y_0 = pd.Categorical(df['species']).codes
    x_n = 'sepal_length'
    x_0 = df[x_n].values
    #上面已經將數據已經表示成了合適的格式,用PyMC3來建模
    #用with語句定義了一個上下文管理器 context manager ,並定義了一個新的模型對象,這個對象是模型中隨機變量的容器
    with pm.Model() as model_0:
        #上下文中定義了兩個具有正態分佈先驗的隨機性隨機變量
        alpha = pm.Normal('alpha', mu=0, sd=10)
        beta = pm.Normal('beta', mu=0, sd=10)
        #兩個確定變量:theta 和 bd。theta 是對變量 mu 應用邏輯函數之後的值,bd 是一個有邊界的值,用於確定分類結果
        mu = alpha + pm.math.dot(x_0, beta)
        theta = pm.Deterministic('theta', 1 / (1 + pm.math.exp(-mu)))
        bd = pm.Deterministic('bd', -alpha / beta)
        #二元分類->伯努利可能性
        yl = pm.Bernoulli('yl', theta, observed=y_0)
        start = pm.find_MAP()
        step = pm.NUTS()
        #step 參數指定特定的採樣器(迭代器)來替換默認的迭代器NUTS
        trace_0 = pm.sampling.sample(500, step, start)  # 本地運行時,推薦將迭代次數設置爲大於 1000 次,這裏是500+默認500=1000
    varnames = ['alpha', 'beta', 'bd']
    pm.traceplot(trace_0, varnames)
    theta = trace_0['theta'].mean(axis=0)
    idx = np.argsort(x_0)
    plt.plot(x_0[idx], theta[idx], color='b', lw=3);
    #繪製一條橫跨當前圖表的垂直/水平輔助線,x:恆座標,
    plt.axvline(trace_0['bd'].mean(), ymax=1, color='r')
    bd_hpd = pm.hpd(trace_0['bd'])
    #水平防線,x1和x2之間填充
    plt.fill_betweenx([0, 1], bd_hpd[0], bd_hpd[1], color='r', alpha=0.5)
    #繪製折線圖
    plt.plot(x_0, y_0, 'o', color='k')
    theta_hpd = pm.hpd(trace_0['theta'])[idx]
    #y1和y2之間填充
    plt.fill_between(x_0[idx], theta_hpd[:, 0], theta_hpd[:, 1], color='b', alpha=0.5)
    plt.xlabel(x_n, fontsize=16)
    plt.ylabel(r'$\theta$', rotation=0, fontsize=16)
    plt.show()

運行結果:

這張圖表示了花萼長度與花的種類(setosa = 0, versicolor =1)之間的關係。藍色的 S 型曲線表示 theta 的均值,這條線可以解釋爲:在知道花萼長度的情況下花的種類是 versicolor 的概率。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章