统计学习方法之逻辑回归(Logistic Regression)

逻辑回归之所以叫逻辑是因为他用到了逻辑分布:
这里写图片描述
图形如下:
这里写图片描述
还是按照老样子,根据李航老师的统计学习方法三部分进行学习。
1 模型
假设输入为任意范围内的属性值,输出为0-1之间的概率。给定一个阈值,当概率大于该阈值时,Y = 1,否则Y= 0。(在等于阈值部分随意设定,毕竟一点之差不算差)
利用逻辑分布表示之:
这里写图片描述
w*x表示,对不同属性的权值做了一个设定,b表示偏移量。也可以以x0 = 1, w0 = b的方式写到w*x上。
这里写图片描述
在这里注意一点,林轩田老师的视频中Y用的是-1或1,这样可以用到逻辑函数中心对称的性质:

  1-logit(x) = logit(-x)

这个性质在以后最大似然估计和梯度下降时有重大简化作用。当然,本文先使用0,1解决,在用-1,1解决一次。
李航老师书中提到了机率(odds)的概念,但是在后期的解答过程中并没有涉及到相关问题,我就不再记下来了。应该是解释为什么输出Y是一个离散值,但却叫回归的原因。

2 策略
为了得到最合适的w,我们应当采取合适的策略进行学习。逻辑回归模型采用的是最大似然估计的方式。
似然函数:
这里写图片描述
那么问题来了,什么是似然函数?

设总体X的分布形式P(x;w)已知,其中w属于W未知,x为X的样本值,称L(w) = 求积i=1->n(P(xi, w))为参数w的似然函数。
意义:在已知样本结果的前提下,来估计满足这些样本分布的参数w,选择可能性最大的w作为真实的w。

这是《概率论与数理统计》上大概的意思,为什么要用乘积?因为要使得训练集里每个样本的估计值都最大才行。为方便运算,进行对数操作。
这里写图片描述

3 算法
由上述可知,为了找到合适的w,要求的w使得L(w)最大,林轩田老师采用取负数的形式,以此利用梯度下降发来求最小值。《概率论与数理统计》中直接对L(w)求导,然后直接得到结果。我们采用林轩田老师的方式。
这里写图片描述
此为每次更新的部分。
梯度下降:
repeat {
这里写图片描述
}until convergence

matlab代码实现:

function W = LogisticRegression(X, Y, alpha)
    n = length(Y)%the number of samples
    m = size(X, 2) % the number of attributes
    W = zeros(m, 1)%init

    %sum保证每个属性的迭代都等于0
    while sum(Cost(X, Y, W) ~= zeros(length(Y), 1))
        W = W - alpha*X'*Cost(X, Y, W)
    end
end


%计算预测值,返回一个向量,
%分别为对该样本集中的单个样本类别为1的预测概率
function yy = PI(W, X)
    %yy = exp(X*W)./(1+exp(X*W))
    expX = exp(-X*W)
    deno = 1 + expX
    yy = 1./(deno)
    end

%代价函数,我们需要迭代最终为零,
%得到的是一个m*1的矩阵,用来对w进行更新
%***在判断中用了一个abs(),因为我的例子中出现了极端情况,如预测值为0.9990,然后死循环了。为避免之故求了一个范围,如有更好方法求解答!***
function ww = Cost(X, Y, W)
    if(abs(PI(W, X) - Y)) <= 0.001
        ww = 0
    else
        ww = PI(W, X) - Y
    end 
end

测试函数:

load data2.txt
X = data2(:, 1:2)
Y = data2(:, 5)
X1 = X(1:49, :)
X2 = X(51:98, :)

%画出两种样本
plot(X1(:, 1),X1(:, 2),'X')
hold on
plot(X2(:, 1),X2(:, 2),'O')
input = [ones(length(Y),1), X]
W = LogisticRegression(input, Y, 0.3)

%输入测试集合
testX = [5 4; 6.2 3.2 ; 5.7 3.3; 4.2 3.5; 5 2.7]
testInput = [ones(size(testX, 1),1) testX]
testY = 1./(1+exp(-testInput*W))

for i = 1:length(testY)
    if testY(i) > 0.5
        plot(testX(i,1), testX(i, 2), 's');
    else
        plot(testX(i,1), testX(i, 2), 'p');
    end
end

结果

学习结果:
这里写图片描述
测试用例及预测:
这里写图片描述
可视化显示
这里写图片描述
可见,学习成功啦!^_^

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章