A Plain-Language Introduction to Machine Learning: Solving Binary Classification with Gradient Descent

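Review

Earlier articles in this series used gradient descent to solve regression problems. How can the same method be used for binary classification?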

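Problem

Suppose we have a dataset $D = \{(x, y)\}$, where $x$ is the observed data and $y$ is the class it belongs to. We want to build a model that, given $x$, returns its class.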

Dataset

We generate a dataset with MATLAB.

% Create test data
x1 = [normrnd(3,1,40,1) normrnd(3,2,40,1)];
x2 = [normrnd(7,1,40,1) normrnd(6,2,40,1)];
hold off;
plot(x1(:,1),x1(:,2),'or');
hold on;
plot(x2(:,1),x2(:,2),'ob');
y1 = zeros(40,1);
y2 = ones(40,1);
x = [x1;x2]; 
y = [y1;y2];
hold off;

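The data values are as follows (the first two columns are the features in x, the third column is the label y):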

D =

    1.5764    1.8386         0
    1.1573    4.6626         0
    3.1160    3.1710         0
    3.3631    2.4825         0
    3.1122    3.9442         0
    2.3633    1.1272         0
    2.4668    4.7144         0
    1.2313    2.2046         0
    4.0816    3.4433         0
    3.6820    6.8999         0
    3.3439    2.3597         0
    2.8488    5.1085         0
    4.0656    3.4616         0
    2.2035    7.1164         0
    3.8557    4.4835         0
    4.3226    2.2885         0
    6.0615    0.7585         0
    2.4069    4.0333         0
    2.9405    0.0246         0
    3.6190   -0.3831         0
    3.4319    5.0499         0
    4.7547    0.5282         0
    1.0382    2.2332         0
    3.7447    4.8796         0
    3.9460   -0.3395         0
    4.8710   -0.9450         0
    4.3029    3.2495         0
    2.0460    3.8779         0
    2.2227   -0.1015         0
    2.0860    2.4868         0
    3.1852    5.5069         0
    2.5845    1.6710         0
    2.5240    1.0334         0
    0.8195    4.4851         0
    3.5216    3.3754         0
    4.1698    2.6556         0
    4.9099    1.5497         0
    3.7109    6.1300         0
    3.7440    0.5363         0
    2.5802   -0.9823         0
    6.6795    7.8466    1.0000
    6.9358    5.0137    1.0000
    6.2642    8.0141    1.0000
    8.8116    4.0538    1.0000
    7.4373    7.0817    1.0000
    6.0725    4.4535    1.0000
    5.1592    2.6908    1.0000
    6.8279    5.9248    1.0000
    8.1433    5.4880    1.0000
    6.8186    5.3189    1.0000
    8.0123    6.7615    1.0000
    6.9431    6.6514    1.0000
    6.4406    4.6153    1.0000
    6.6724    6.3087    1.0000
    6.9785    3.4098    1.0000
    6.5140    7.9431    1.0000
    7.6444    2.4133    1.0000
    8.2008    3.7157    1.0000
    7.4970    3.8444    1.0000
    5.9602    5.5373    1.0000
    7.4493    6.2110    1.0000
    6.6591    6.8626    1.0000
    6.7736    8.0535    1.0000
    5.1214    6.4982    1.0000
    7.6140    2.6426    1.0000
    7.2352    5.1972    1.0000
    7.9788    7.0529    1.0000
    7.4953   10.8854    1.0000
    6.0268    5.5639    1.0000
    7.3062    6.3531    1.0000
    7.9649    5.8976    1.0000
    6.5187    8.1833    1.0000
    6.7009    7.4581    1.0000
    4.7819    6.3288    1.0000
    5.7821    5.8203    1.0000
    7.6720    7.9396    1.0000
    9.4654    1.2340    1.0000
    6.1818    9.1185    1.0000
    7.9709    8.0687    1.0000
    8.7071    7.0849    1.0000
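
For readers without MATLAB, a dataset of the same shape can be sketched in plain Python. This is only an illustration: the function name `make_dataset` and the seed value are assumptions, and the random draws will not reproduce the numbers listed above.

```python
import random

def make_dataset(n_per_class=40, seed=0):
    """Draw two Gaussian clusters with the same means and spreads as the MATLAB snippet."""
    rng = random.Random(seed)  # the seed value is an arbitrary assumption
    samples, labels = [], []
    # Class 0: feature means (3, 3), standard deviations (1, 2)
    for _ in range(n_per_class):
        samples.append([rng.gauss(3, 1), rng.gauss(3, 2)])
        labels.append(0)
    # Class 1: feature means (7, 6), standard deviations (1, 2)
    for _ in range(n_per_class):
        samples.append([rng.gauss(7, 1), rng.gauss(6, 2)])
        labels.append(1)
    return samples, labels

x, y = make_dataset()
print(len(x), sum(y))  # number of samples, number of class-1 samples
```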

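To show the data intuitively, we plot the points in different colors according to their y value.
[Figure: scatter plot of the two classes, class 0 in red and class 1 in blue]

Plotting the two features and the label y as a 3D figure gives:
[Figure: 3D plot of the two features against the label y]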

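Model

In the previous article we fitted a line using a function of the form $y = wx + b$. In this problem y is either 0 or 1, and a linear function cannot quickly approach the two values 0 and 1. To solve this, we introduce an activation function.

$$f(z) = \frac{1}{1 + e^{-z}}$$

Plot the curve of this function with MATLAB: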

>> x = -10:0.1:10;
>> y = 1 ./ (1+ exp(-x));
>> plot(x,y)

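[Figure: the sigmoid function]
This function is called the logistic function, or sigmoid function. It approaches 0 as x tends to negative infinity and approaches 1 as x tends to positive infinity. Its range is (0, 1), the same as that of a probability. Near x = 0 the value of y changes steeply, so the function is sensitive to changes in x there.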
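
The limiting behavior described above can be checked numerically. The following is a minimal Python sketch, not part of the original MATLAB code; the two-branch form is an added detail that avoids overflow in `math.exp` for large negative inputs.

```python
import math

def sigmoid(z):
    """Logistic function 1 / (1 + e^(-z)), written to avoid overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)          # z < 0, so exp(z) cannot overflow
    return ez / (1.0 + ez)

# Values approach 0 on the far left, 1 on the far right, and equal 0.5 at z = 0
for z in (-10, -1, 0, 1, 10):
    print(z, round(sigmoid(z), 4))
```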

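Setting $z = w^T x$ in the formula above gives our model

$$f(x) = \frac{1}{1 + e^{-w^T x}}$$

where w is the parameter column vector and x is the sample column vector.
Some readers may ask where the b in $z = w^T x + b$ went. We turn x into an augmented vector, $x = [1, x_1, x_2, x_3, x_4, \dots, x_k]$. The 0th component of w, $w_0$, is then the b of the original formula.

$$z = w^T x = w_0 + w_1 x_1 + w_2 x_2 + \dots + w_k x_k$$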
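
A short Python sketch of the augmentation trick follows. The helper names `augment` and `linear_score`, and the weights and sample, are hypothetical values chosen only for illustration.

```python
def augment(x):
    """Prepend a constant 1 so that w[0] plays the role of the bias b."""
    return [1.0] + list(x)

def linear_score(w, x_aug):
    """z = w^T x for an augmented sample."""
    return sum(wi * xi for wi, xi in zip(w, x_aug))

w = [0.5, 2.0, -1.0]      # hypothetical weights: b = 0.5, w1 = 2, w2 = -1
x = [3.0, 4.0]            # hypothetical sample
z = linear_score(w, augment(x))
print(z)  # 0.5 + 2*3 - 1*4 = 2.5
```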

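Parameter estimation

Estimating w from the dataset above is called parameter estimation. We first construct a loss function. Because f(x) approximates a probability, we use the standard method for estimating the parameters of a probability model: maximum likelihood estimation.

Maximum likelihood estimation means: given a set of samples, find the parameter under which the probability of observing those samples is largest.

Let $x^{(i)}, y^{(i)}$ denote the data and class of sample i.

We set

$$p(y^{(i)} = 1 \mid w) = f(x^{(i)})$$

The log-likelihood function is

$$
\begin{aligned}
L(w) &= \ln \prod_{(x^{(i)}, y^{(i)}) \in D} p(y^{(i)} = 1 \mid w)^{y^{(i)}} \bigl(1 - p(y^{(i)} = 1 \mid w)\bigr)^{1 - y^{(i)}} \\
&= \ln \prod_{i} f(x^{(i)})^{y^{(i)}} \bigl(1 - f(x^{(i)})\bigr)^{1 - y^{(i)}} \\
&= \sum_{i} \Bigl[ y^{(i)} \ln f(x^{(i)}) + (1 - y^{(i)}) \ln\bigl(1 - f(x^{(i)})\bigr) \Bigr]
\end{aligned}
$$

The larger this value the better, so we define the loss function as its negative.

$$l(w) = -L(w)$$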
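
To make the loss concrete, here is a small Python sketch that evaluates the negative log-likelihood on a toy dataset. The function name `neg_log_likelihood` and the data are assumptions for illustration; the samples are already augmented with a leading 1.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def neg_log_likelihood(w, xs, ys):
    """l(w) = -sum_i [ y_i*ln f(x_i) + (1 - y_i)*ln(1 - f(x_i)) ]."""
    total = 0.0
    for x, y in zip(xs, ys):
        f = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
        total += y * math.log(f) + (1 - y) * math.log(1.0 - f)
    return -total

# Hypothetical augmented samples [1, x1] with labels
xs = [[1.0, -2.0], [1.0, -1.0], [1.0, 1.0], [1.0, 2.0]]
ys = [0, 0, 1, 1]
loss_zero = neg_log_likelihood([0.0, 0.0], xs, ys)   # every f is 0.5
loss_good = neg_log_likelihood([0.0, 2.0], xs, ys)   # weights that separate the classes
print(loss_zero, loss_good)
```

With all weights at zero every prediction is 0.5, so the loss equals 4 ln 2; weights that separate the two classes give a smaller loss.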

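To find the w that minimizes the loss function, we again use gradient descent. Concretely, we first differentiate $l(w)$ and then apply the following update in each iteration, where $a$ is the learning rate

$$w = w - a \cdot \frac{\partial l(w)}{\partial w}$$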

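Derivation

To differentiate this formula, we first differentiate f(z) and z.

$$f(z) = \frac{1}{1 + e^{-z}}$$

$$
\begin{aligned}
\frac{\partial f}{\partial z} &= -\frac{1}{(1 + e^{-z})^2} \cdot e^{-z} \cdot (-1) = \frac{e^{-z}}{(1 + e^{-z})^2} = \frac{1 + e^{-z} - 1}{(1 + e^{-z})^2} \\
&= \frac{1}{1 + e^{-z}} - \frac{1}{(1 + e^{-z})^2} = f(z) - f(z)^2 = f(z)\bigl(1 - f(z)\bigr)
\end{aligned}
$$

$$z = w^T x$$

$$\frac{\partial z}{\partial w} = x$$

As we can see, the derivative of $f(z)$ is $f(z)(1 - f(z))$.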
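
The identity for the derivative of f(z) can be sanity-checked against a central finite difference. This Python sketch is an added illustration; the step size `h` and the test points are arbitrary choices.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_grad(z):
    """Analytic derivative f(z) * (1 - f(z))."""
    f = sigmoid(z)
    return f * (1.0 - f)

h = 1e-6
for z in (-3.0, 0.0, 2.5):
    numeric = (sigmoid(z + h) - sigmoid(z - h)) / (2 * h)   # central difference
    print(z, sigmoid_grad(z), numeric)
```

The two printed columns should agree to many decimal places.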

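Now take the partial derivative of $l(w)$ with respect to w. To simplify the derivation, we omit the superscripts in all formulas.

$$\frac{\partial l}{\partial w} = \frac{\partial l}{\partial f} \cdot \frac{\partial f}{\partial z} \cdot \frac{\partial z}{\partial w}$$

$$= -\left( \frac{y}{f} - \frac{1 - y}{1 - f} \right) \frac{\partial f}{\partial z} \cdot \frac{\partial z}{\partial w}$$

$$= -\frac{y - f}{f(1 - f)} \cdot \frac{\partial f}{\partial z} \cdot \frac{\partial z}{\partial w}$$

$$= -\frac{y - f}{f(1 - f)} \cdot f(1 - f) \cdot x = (f - y)\,x$$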
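
The result (f - y)x can be verified numerically for a single sample by comparing it with finite differences of the loss. The weights and the sample below are hypothetical values used only for this check.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def loss(w, x, y):
    """Single-sample loss -[y ln f + (1 - y) ln(1 - f)]."""
    f = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return -(y * math.log(f) + (1 - y) * math.log(1.0 - f))

def grad(w, x, y):
    """Analytic gradient (f - y) * x."""
    f = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return [(f - y) * xi for xi in x]

w = [0.1, -0.3, 0.2]      # hypothetical weights
x = [1.0, 2.0, -1.5]      # hypothetical augmented sample
y = 1
h = 1e-6
numeric = []
for j in range(len(w)):
    w_plus = list(w)
    w_plus[j] += h
    w_minus = list(w)
    w_minus[j] -= h
    numeric.append((loss(w_plus, x, y) - loss(w_minus, x, y)) / (2 * h))
print(grad(w, x, y))
print(numeric)
```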

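This formula looks familiar: it closely resembles the one for linear regression. In linear regression $f = wx$, while here $f = \mathrm{sigmoid}(wx)$, but the form is the same in both cases:

(estimate - true value) * independent variable, summed over all samples

The iteration formula is:

$$w = w - a\,(f - y)\,x$$

Here f - y is the difference between the estimate and the true value. The larger this difference, the more w needs to be adjusted, and the adjustment is proportional to it. Likewise, the larger x is, the larger the adjustment.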
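
Putting the update rule together, a minimal batch gradient descent loop can be sketched in plain Python. This is not the linked Python implementation or the MATLAB function from this article: the function name `train`, the fixed iteration count, the zero initial weights, and the tiny one-feature dataset are all assumptions made to keep the example short and deterministic.

```python
import math

def sigmoid(z):
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def train(xs, ys, rate=0.01, iters=2000):
    """Batch gradient descent: w <- w - rate * sum_i (f_i - y_i) * x_i."""
    w = [0.0] * len(xs[0])
    for _ in range(iters):
        g = [0.0] * len(w)
        for x, y in zip(xs, ys):
            f = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
            for j, xj in enumerate(x):
                g[j] += (f - y) * xj
        w = [wj - rate * gj for wj, gj in zip(w, g)]
    return w

# Hypothetical 1-D data, augmented with a leading 1.
# The classes overlap: 0.5 belongs to class 0 and -0.5 to class 1.
values = (-4.0, -3.0, -2.0, -1.0, 0.5, -0.5, 1.0, 2.0, 3.0, 4.0)
xs = [[1.0, v] for v in values]
ys = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

w = train(xs, ys)
preds = [1 if sigmoid(w[0] + w[1] * x[1]) > 0.5 else 0 for x in xs]
accuracy = sum(p == y for p, y in zip(preds, ys)) / len(ys)
print(w, accuracy)
```

Because the two classes overlap, no threshold on the single feature classifies every point correctly, which mirrors the handful of misclassified points in the two-feature experiment.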

Python implementation

http://blog.csdn.net/taiji1985/article/details/51250860

MATLAB implementation

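function [w,f,c,accury] = lr_predict(x,y)
    [n,k] = size(x); % n is the number of samples, k the number of dimensions

    % augment x with a leading column of ones
    x = [ones(n,1) x ];   
    % random initial value for w
    w = rand(1,k+1); % a row vector makes it convenient to compute w*x'
    olde = 0;
    e = 1;
    eps = 0.0001 ;  % convergence threshold (note: shadows MATLAB's built-in eps)
    rate = 0.01;    % learning rate
    i = 0;  % iteration counter
    while true
       z = w*x';
       f = 1./(1+exp(-z));
       e = sum(abs(f-y'))/n; % mean absolute error
       w = w - rate*(f-y')*x; % update the weights
       d = abs(olde  -e); % change in the error between two iterations
       fprintf('%d iter e = %f , d = %f \n',i,e,d);  
       if d < eps
          break; 
       end
       olde =e;       
       i= i+1;
    end
    c = f>0.5;
    accury = (n - sum(abs(c-y')))/n; % accuracy

    fprintf('accury is %f ',accury);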



% Create test data

seed = 333;
rand('seed',seed)
x1 = [normrnd(3,1,40,1) normrnd(3,2,40,1)];
x2 = [normrnd(7,1,40,1) normrnd(6,2,40,1)];
fig_on = 1;
if fig_on
    hold off;
    plot(x1(:,1),x1(:,2),'or');
    hold on;
    plot(x2(:,1),x2(:,2),'ob');
end
y1 = zeros(40,1);
y2 = ones(40,1);
x = [x1;x2]; 
y = [y1;y2];
if fig_on
    hold off;
    figure(2);
     plot3(x1(:,1),x1(:,2),y1,'or');
    hold on;
     plot3(x2(:,1),x2(:,2),y2,'ob');
end
%surf(x(:,1),x(:,2),y);

% Run the classification
[w,f,c,a] = lr_predict(x,y)
if fig_on
    hold off;
    figure(3);
    plot(x1(:,1),x1(:,2),'or');
    hold on;
    plot(x2(:,1),x2(:,2),'ob');
    xe = x(find(c ~= y'),:)
    plot(xe(:,1),xe(:,2),'sm','MarkerSize',10,'LineWidth',2);
    % Optional: draw the decision boundary w(1) + w(2)*x1 + w(3)*x2 = 0
    %mm = min(x);
    %mx = max(x);
    %xx =  mm(1):0.1:mx(1); 
    %yy = -(w(1)+w(2)*xx)/w(3);
    %plot(xx,yy);
end

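Experimental results

[Figure: classification result with the misclassified points marked]

Misclassified points are shown in magenta.

The MATLAB output is: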

0 iter e = 0.476728 , d = 0.476728 
1 iter e = 0.527496 , d = 0.050768 
2 iter e = 0.499542 , d = 0.027954 
3 iter e = 0.485917 , d = 0.013625 
4 iter e = 0.543169 , d = 0.057251 
5 iter e = 0.499576 , d = 0.043593 
6 iter e = 0.488516 , d = 0.011060 
7 iter e = 0.553108 , d = 0.064592 
8 iter e = 0.499505 , d = 0.053603 
9 iter e = 0.487229 , d = 0.012276 
10 iter e = 0.544575 , d = 0.057346 
11 iter e = 0.499370 , d = 0.045205 
12 iter e = 0.484179 , d = 0.015191 
13 iter e = 0.532477 , d = 0.048298 
14 iter e = 0.499217 , d = 0.033260 
15 iter e = 0.481059 , d = 0.018158 
16 iter e = 0.520977 , d = 0.039918 
17 iter e = 0.498982 , d = 0.021995 
18 iter e = 0.476170 , d = 0.022812 
19 iter e = 0.510895 , d = 0.034724 
20 iter e = 0.498714 , d = 0.012181 
21 iter e = 0.471163 , d = 0.027551 
22 iter e = 0.499601 , d = 0.028438 
23 iter e = 0.498290 , d = 0.001311 
24 iter e = 0.463343 , d = 0.034947 
25 iter e = 0.492636 , d = 0.029293 
26 iter e = 0.497855 , d = 0.005219 
27 iter e = 0.456635 , d = 0.041220 
28 iter e = 0.479775 , d = 0.023140 
29 iter e = 0.497030 , d = 0.017256 
30 iter e = 0.443815 , d = 0.053215 
31 iter e = 0.478942 , d = 0.035127 
32 iter e = 0.496471 , d = 0.017529 
33 iter e = 0.438143 , d = 0.058328 
34 iter e = 0.456484 , d = 0.018341 
35 iter e = 0.494499 , d = 0.038015 
36 iter e = 0.412978 , d = 0.081521 
37 iter e = 0.473886 , d = 0.060908 
38 iter e = 0.494617 , d = 0.020731 
39 iter e = 0.420671 , d = 0.073946 
40 iter e = 0.412639 , d = 0.008031 
41 iter e = 0.487323 , d = 0.074684 
42 iter e = 0.345426 , d = 0.141897 
43 iter e = 0.482066 , d = 0.136640 
44 iter e = 0.492669 , d = 0.010603 
45 iter e = 0.409059 , d = 0.083610 
46 iter e = 0.329026 , d = 0.080033 
47 iter e = 0.459020 , d = 0.129994 
48 iter e = 0.160198 , d = 0.298822 
49 iter e = 0.323109 , d = 0.162911 
50 iter e = 0.453148 , d = 0.130038 
51 iter e = 0.148804 , d = 0.304344 
52 iter e = 0.277777 , d = 0.128973 
53 iter e = 0.423044 , d = 0.145267 
54 iter e = 0.133613 , d = 0.289431 
55 iter e = 0.173894 , d = 0.040281 
56 iter e = 0.347993 , d = 0.174099 
57 iter e = 0.457374 , d = 0.109381 
58 iter e = 0.199874 , d = 0.257501 
59 iter e = 0.378111 , d = 0.178237 
60 iter e = 0.464773 , d = 0.086662 
61 iter e = 0.258071 , d = 0.206702 
62 iter e = 0.402756 , d = 0.144686 
63 iter e = 0.468620 , d = 0.065864 
64 iter e = 0.292597 , d = 0.176024 
65 iter e = 0.368138 , d = 0.075541 
66 iter e = 0.453801 , d = 0.085663 
67 iter e = 0.229217 , d = 0.224584 
68 iter e = 0.336790 , d = 0.107573 
69 iter e = 0.436250 , d = 0.099459 
70 iter e = 0.163286 , d = 0.272964 
71 iter e = 0.232334 , d = 0.069048 
72 iter e = 0.356480 , d = 0.124146 
73 iter e = 0.150146 , d = 0.206334 
74 iter e = 0.236759 , d = 0.086612 
75 iter e = 0.299827 , d = 0.063068 
76 iter e = 0.407203 , d = 0.107377 
77 iter e = 0.099765 , d = 0.307438 
78 iter e = 0.109024 , d = 0.009258 
79 iter e = 0.137674 , d = 0.028651 
80 iter e = 0.166169 , d = 0.028494 
81 iter e = 0.263474 , d = 0.097305 
82 iter e = 0.267471 , d = 0.003997 
83 iter e = 0.377094 , d = 0.109624 
84 iter e = 0.082407 , d = 0.294687 
85 iter e = 0.082223 , d = 0.000184 
86 iter e = 0.082569 , d = 0.000347 
87 iter e = 0.082813 , d = 0.000244 
88 iter e = 0.083618 , d = 0.000805 
89 iter e = 0.084850 , d = 0.001232 
90 iter e = 0.086432 , d = 0.001582 
91 iter e = 0.090219 , d = 0.003787 
92 iter e = 0.093099 , d = 0.002880 
93 iter e = 0.103860 , d = 0.010761 
94 iter e = 0.109086 , d = 0.005226 
95 iter e = 0.139819 , d = 0.030733 
96 iter e = 0.151821 , d = 0.012001 
97 iter e = 0.231801 , d = 0.079980 
98 iter e = 0.230153 , d = 0.001648 
99 iter e = 0.335810 , d = 0.105658 
100 iter e = 0.108298 , d = 0.227512 
101 iter e = 0.136344 , d = 0.028046 
102 iter e = 0.135687 , d = 0.000657 
103 iter e = 0.196323 , d = 0.060637 
104 iter e = 0.191395 , d = 0.004928 
105 iter e = 0.288650 , d = 0.097254 
106 iter e = 0.168363 , d = 0.120287 
107 iter e = 0.253856 , d = 0.085494 
108 iter e = 0.190190 , d = 0.063667 
109 iter e = 0.283906 , d = 0.093716 
110 iter e = 0.158495 , d = 0.125411 
111 iter e = 0.234220 , d = 0.075725 
112 iter e = 0.178957 , d = 0.055263 
113 iter e = 0.265306 , d = 0.086349 
114 iter e = 0.161863 , d = 0.103443 
115 iter e = 0.236991 , d = 0.075128 
116 iter e = 0.166209 , d = 0.070782 
117 iter e = 0.242725 , d = 0.076516 
118 iter e = 0.159634 , d = 0.083091 
119 iter e = 0.230343 , d = 0.070709 
120 iter e = 0.155406 , d = 0.074936 
121 iter e = 0.221664 , d = 0.066258 
122 iter e = 0.149961 , d = 0.071703 
123 iter e = 0.210649 , d = 0.060688 
124 iter e = 0.143460 , d = 0.067189 
125 iter e = 0.197445 , d = 0.053984 
126 iter e = 0.135411 , d = 0.062034 
127 iter e = 0.181018 , d = 0.045607 
128 iter e = 0.125077 , d = 0.055941 
129 iter e = 0.159931 , d = 0.034854 
130 iter e = 0.111675 , d = 0.048256 
131 iter e = 0.133207 , d = 0.021532 
132 iter e = 0.095523 , d = 0.037684 
133 iter e = 0.103705 , d = 0.008182 
134 iter e = 0.080177 , d = 0.023529 
135 iter e = 0.080751 , d = 0.000574 
136 iter e = 0.070453 , d = 0.010298 
137 iter e = 0.069785 , d = 0.000668 
138 iter e = 0.066274 , d = 0.003512 
139 iter e = 0.066042 , d = 0.000232 
140 iter e = 0.064806 , d = 0.001236 
141 iter e = 0.064808 , d = 0.000002 
accury is 0.975000 
w =

  -10.0562    1.5878    0.4737


f =

  Columns 1 through 8

    0.0014    0.0028    0.0314    0.0332    0.0451    0.0035    0.0238    0.0009

  Columns 9 through 16

    0.1501    0.3391    0.0304    0.0520    0.1480    0.0497    0.1703    0.1283

  Columns 17 through 24

    0.5347    0.0155    0.0051    0.0124    0.1204    0.1101    0.0007    0.1727

  Columns 25 through 32

    0.0213    0.0672    0.1867    0.0081    0.0015    0.0043    0.1033    0.0065

  Columns 33 through 40

    0.0043    0.0015    0.0646    0.1212    0.2088    0.2685    0.0237    0.0017

  Columns 41 through 48

    0.9905    0.9750    0.9831    0.9980    0.9959    0.8795    0.4102    0.9809

  Columns 49 through 56

    0.9971    0.9740    0.9981    0.9888    0.9348    0.9797    0.9496    0.9882

  Columns 57 through 64

    0.9712    0.9937    0.9818    0.9124    0.9939    0.9842    0.9926    0.8119

  Columns 65 through 72

    0.9730    0.9858    0.9983    0.9994    0.9216    0.9928    0.9969    0.9896

  Columns 73 through 80

    0.9889    0.6962    0.8998    0.9982    0.9972    0.9887    0.9989    0.9995


c =

  Columns 1 through 14

     0     0     0     0     0     0     0     0     0     0     0     0     0     0

  Columns 15 through 28

     0     0     1     0     0     0     0     0     0     0     0     0     0     0

  Columns 29 through 42

     0     0     0     0     0     0     0     0     0     0     0     0     1     1

  Columns 43 through 56

     1     1     1     1     0     1     1     1     1     1     1     1     1     1

  Columns 57 through 70

     1     1     1     1     1     1     1     1     1     1     1     1     1     1

  Columns 71 through 80

     1     1     1     1     1     1     1     1     1     1


a =

    0.9750


xe =

    6.0615    0.7585
    5.1592    2.6908
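
As a follow-up, the learned weights define a straight decision boundary where $w_0 + w_1 x_1 + w_2 x_2 = 0$. The Python sketch below uses the rounded weights printed in the output above; the helper names `boundary_x2` and `score` are assumptions. Because the printed weights are rounded to four decimals, points very close to the boundary may come out differently than in the MATLAB run.

```python
# Weights copied from the printed MATLAB output (rounded to 4 decimals)
w = [-10.0562, 1.5878, 0.4737]

def boundary_x2(x1, w):
    """Solve w0 + w1*x1 + w2*x2 = 0 for x2 (requires w2 != 0)."""
    return -(w[0] + w[1] * x1) / w[2]

def score(x1, x2, w):
    """Linear score z; z > 0 means the model predicts class 1."""
    return w[0] + w[1] * x1 + w[2] * x2

for x1 in (5.0, 6.0, 7.0):
    x2 = boundary_x2(x1, w)
    print(x1, round(x2, 4), score(x1, x2, w))   # the score is approximately 0 on the boundary

# The cluster centers used to generate the data fall on opposite sides
print(score(3, 3, w) < 0, score(7, 6, w) > 0)
```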