統計學習方法之邏輯迴歸(Logistic Regression)

邏輯迴歸之所以叫邏輯是因爲他用到了邏輯分佈:
這裏寫圖片描述
圖形如下:
這裏寫圖片描述
還是按照老樣子,根據李航老師的統計學習方法三部分進行學習。
1 模型
假設輸入爲任意範圍內的屬性值,輸出爲0-1之間的概率。給定一個閾值,當概率大於該閾值時,Y = 1,否則Y= 0。(在等於閾值部分隨意設定,畢竟一點之差不算差)
利用邏輯分佈表示之:
這裏寫圖片描述
w*x表示,對不同屬性的權值做了一個設定,b表示偏移量。也可以以x0 = 1, w0 = b的方式寫到w*x上。
這裏寫圖片描述
在這裏注意一點,林軒田老師的視頻中Y用的是-1或1,這樣可以用到邏輯函數中心對稱的性質:

  1-logit(x) = logit(-x)

這個性質在以後最大似然估計和梯度下降時有重大簡化作用。當然,本文先使用0,1解決,在用-1,1解決一次。
李航老師書中提到了機率(odds)的概念,但是在後期的解答過程中並沒有涉及到相關問題,我就不再記下來了。應該是解釋爲什麼輸出Y是一個離散值,但卻叫回歸的原因。

2 策略
爲了得到最合適的w,我們應當採取合適的策略進行學習。邏輯迴歸模型採用的是最大似然估計的方式。
似然函數:
這裏寫圖片描述
那麼問題來了,什麼是似然函數?

設總體X的分佈形式P(x;w)已知,其中w屬於W未知,x爲X的樣本值,稱L(w) = 求積i=1->n(P(xi, w))爲參數w的似然函數。
意義:在已知樣本結果的前提下,來估計滿足這些樣本分佈的參數w,選擇可能性最大的w作爲真實的w。

這是《概率論與數理統計》上大概的意思,爲什麼要用乘積?因爲要使得訓練集裏每個樣本的估計值都最大才行。爲方便運算,進行對數操作。
這裏寫圖片描述

3 算法
由上述可知,爲了找到合適的w,要求的w使得L(w)最大,林軒田老師採用取負數的形式,以此利用梯度下降發來求最小值。《概率論與數理統計》中直接對L(w)求導,然後直接得到結果。我們採用林軒田老師的方式。
這裏寫圖片描述
此爲每次更新的部分。
梯度下降:
repeat {
這裏寫圖片描述
}until convergence

matlab代碼實現:

function W = LogisticRegression(X, Y, alpha)
    n = length(Y)%the number of samples
    m = size(X, 2) % the number of attributes
    W = zeros(m, 1)%init

    %sum保證每個屬性的迭代都等於0
    while sum(Cost(X, Y, W) ~= zeros(length(Y), 1))
        W = W - alpha*X'*Cost(X, Y, W)
    end
end


%計算預測值,返回一個向量,
%分別爲對該樣本集中的單個樣本類別爲1的預測概率
function yy = PI(W, X)
    %yy = exp(X*W)./(1+exp(X*W))
    expX = exp(-X*W)
    deno = 1 + expX
    yy = 1./(deno)
    end

%代價函數,我們需要迭代最終爲零,
%得到的是一個m*1的矩陣,用來對w進行更新
%***在判斷中用了一個abs(),因爲我的例子中出現了極端情況,如預測值爲0.9990,然後死循環了。爲避免之故求了一個範圍,如有更好方法求解答!***
function ww = Cost(X, Y, W)
    if(abs(PI(W, X) - Y)) <= 0.001
        ww = 0
    else
        ww = PI(W, X) - Y
    end 
end

測試函數:

load data2.txt
X = data2(:, 1:2)
Y = data2(:, 5)
X1 = X(1:49, :)
X2 = X(51:98, :)

%畫出兩種樣本
plot(X1(:, 1),X1(:, 2),'X')
hold on
plot(X2(:, 1),X2(:, 2),'O')
input = [ones(length(Y),1), X]
W = LogisticRegression(input, Y, 0.3)

%輸入測試集合
testX = [5 4; 6.2 3.2 ; 5.7 3.3; 4.2 3.5; 5 2.7]
testInput = [ones(size(testX, 1),1) testX]
testY = 1./(1+exp(-testInput*W))

for i = 1:length(testY)
    if testY(i) > 0.5
        plot(testX(i,1), testX(i, 2), 's');
    else
        plot(testX(i,1), testX(i, 2), 'p');
    end
end

結果

學習結果:
這裏寫圖片描述
測試用例及預測:
這裏寫圖片描述
可視化顯示
這裏寫圖片描述
可見,學習成功啦!^_^

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章