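Review
The previous articles showed how to solve regression problems with gradient descent. So how do we solve a binary classification problem?
Problem
Suppose we have the following dataset.
Dataset
We generate a dataset with MATLAB.
% Create test data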
x1 = [normrnd(3,1,40,1) normrnd(3,2,40,1)];
x2 = [normrnd(7,1,40,1) normrnd(6,2,40,1)];
hold off;
plot(x1(:,1),x1(:,2),'or');
hold on;
plot(x2(:,1),x2(:,2),'ob');
y1 = zeros(40,1);
y2 = ones(40,1);
x = [x1;x2];
y = [y1;y2];
hold off;
The data values are (columns: first feature, second feature, label y):
D =
1.5764 1.8386 0
1.1573 4.6626 0
3.1160 3.1710 0
3.3631 2.4825 0
3.1122 3.9442 0
2.3633 1.1272 0
2.4668 4.7144 0
1.2313 2.2046 0
4.0816 3.4433 0
3.6820 6.8999 0
3.3439 2.3597 0
2.8488 5.1085 0
4.0656 3.4616 0
2.2035 7.1164 0
3.8557 4.4835 0
4.3226 2.2885 0
6.0615 0.7585 0
2.4069 4.0333 0
2.9405 0.0246 0
3.6190 -0.3831 0
3.4319 5.0499 0
4.7547 0.5282 0
1.0382 2.2332 0
3.7447 4.8796 0
3.9460 -0.3395 0
4.8710 -0.9450 0
4.3029 3.2495 0
2.0460 3.8779 0
2.2227 -0.1015 0
2.0860 2.4868 0
3.1852 5.5069 0
2.5845 1.6710 0
2.5240 1.0334 0
0.8195 4.4851 0
3.5216 3.3754 0
4.1698 2.6556 0
4.9099 1.5497 0
3.7109 6.1300 0
3.7440 0.5363 0
2.5802 -0.9823 0
6.6795 7.8466 1.0000
6.9358 5.0137 1.0000
6.2642 8.0141 1.0000
8.8116 4.0538 1.0000
7.4373 7.0817 1.0000
6.0725 4.4535 1.0000
5.1592 2.6908 1.0000
6.8279 5.9248 1.0000
8.1433 5.4880 1.0000
6.8186 5.3189 1.0000
8.0123 6.7615 1.0000
6.9431 6.6514 1.0000
6.4406 4.6153 1.0000
6.6724 6.3087 1.0000
6.9785 3.4098 1.0000
6.5140 7.9431 1.0000
7.6444 2.4133 1.0000
8.2008 3.7157 1.0000
7.4970 3.8444 1.0000
5.9602 5.5373 1.0000
7.4493 6.2110 1.0000
6.6591 6.8626 1.0000
6.7736 8.0535 1.0000
5.1214 6.4982 1.0000
7.6140 2.6426 1.0000
7.2352 5.1972 1.0000
7.9788 7.0529 1.0000
7.4953 10.8854 1.0000
6.0268 5.5639 1.0000
7.3062 6.3531 1.0000
7.9649 5.8976 1.0000
6.5187 8.1833 1.0000
6.7009 7.4581 1.0000
4.7819 6.3288 1.0000
5.7821 5.8203 1.0000
7.6720 7.9396 1.0000
9.4654 1.2340 1.0000
6.1818 9.1185 1.0000
7.9709 8.0687 1.0000
8.7071 7.0849 1.0000
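To visualize the data, we plot the points in different colors according to their y value.
We can also plot the two features together with the label y as a 3-D figure.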
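Model
In the previous article we used a linear model for regression. Here the model is built on the following function:
$$f(z) = \frac{1}{1+e^{-z}}$$
We plot the curve of this function with MATLAB: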
>> x = -10:0.1:10;
>> y = 1 ./ (1+ exp(-x));
>> plot(x,y)
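This function is called the logistic function, or sigmoid function. As x tends to negative infinity it approaches 0, and as x tends to positive infinity it approaches 1. Its range is (0,1), the same as the range of a probability. Near x = 0 the curve is steep, so the function is sensitive to changes in x there.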
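These limiting values are easy to confirm numerically. The short sketch below evaluates the function at a few points. The handle name `sigmoid` and the chosen inputs are picked here for illustration and do not appear in the original code.

```matlab
% Logistic (sigmoid) function; the name "sigmoid" is illustrative
sigmoid = @(z) 1 ./ (1 + exp(-z));

z = [-10 -1 0 1 10];
f = sigmoid(z);

% Print each input next to its function value
for k = 1:numel(z)
    fprintf('f(%g) = %.6f\n', z(k), f(k));
end
```

The value at 0 is exactly 0.5, and the values at -10 and 10 are already within 1e-4 of 0 and 1.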
We set $z$ in the formula above to
$$z = w^T x$$
where $w$ is the parameter column vector and $x$ is the sample column vector.
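Some readers may ask where the intercept term went. As in the MATLAB implementation below, $x$ is augmented with a constant 1, so the intercept is absorbed into $w$.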
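Parameter Estimation
Given the dataset above, we need to estimate $w$. We use maximum likelihood estimation.
Maximum likelihood estimation means: given a set of samples, find the parameter value under which the probability of observing those samples is largest.
We use $f(x)$ as the probability that a sample belongs to class 1:
$$P(y=1 \mid x; w) = f(x), \qquad P(y=0 \mid x; w) = 1 - f(x)$$
We let
$$P(y \mid x; w) = f(x)^{y}\,\bigl(1-f(x)\bigr)^{1-y}$$
The likelihood function is
$$L(w) = \prod_{i=1}^{n} f(x_i)^{y_i}\,\bigl(1-f(x_i)\bigr)^{1-y_i}$$
The larger this probability the better, so we define the loss function as the negative of its logarithm (taking the logarithm turns the product into a sum without changing where the maximum is):
$$J(w) = -\ln L(w) = -\sum_{i=1}^{n}\Bigl[y_i \ln f(x_i) + (1-y_i)\ln\bigl(1-f(x_i)\bigr)\Bigr]$$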
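As a small sanity check on a loss of this kind, the sketch below evaluates the negative log-likelihood, $-\sum_i [y_i \ln f(x_i) + (1-y_i)\ln(1-f(x_i))]$, on a tiny made-up dataset. The handle `nll`, the four data points, and the second weight vector are assumptions for illustration, not part of the original article. With $w = 0$ every prediction is 0.5, so the loss should equal $n \ln 2$.

```matlab
sigmoid = @(z) 1 ./ (1 + exp(-z));

% Negative log-likelihood for a row vector w, an augmented sample
% matrix X (one sample per row, first column all ones) and labels y
nll = @(w, X, y) -sum(y' .* log(sigmoid(w * X')) + ...
                      (1 - y') .* log(1 - sigmoid(w * X')));

% Tiny illustrative dataset (made up): two points per class
X = [1 1 2; 1 2 1; 1 6 5; 1 7 7];
y = [0; 0; 1; 1];

J0 = nll([0 0 0], X, y);     % all predictions are 0.5
J1 = nll([-8 1 1], X, y);    % a w that separates the two classes
fprintf('J(w=0) = %.4f, n*log(2) = %.4f\n', J0, numel(y) * log(2));
fprintf('J(w=[-8 1 1]) = %.4f\n', J1);
```

The separating weight vector gives a much smaller loss than $w = 0$, which is the behavior the minimization relies on.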
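To find the $w$ that minimizes the loss function, we again use gradient descent. Specifically, we first differentiate the loss with respect to $w$.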
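To differentiate this expression, we first differentiate $f(z)$ and $z(x)$ separately:
$$f'(z) = \frac{e^{-z}}{(1+e^{-z})^{2}}, \qquad \frac{\partial z}{\partial w} = x$$
As we can see, $f'(z) = f(z)\,\bigl(1-f(z)\bigr)$.
Differentiating the loss with respect to $w$ gives
$$\frac{\partial J}{\partial w} = \sum_{i=1}^{n}\bigl(f(x_i) - y_i\bigr)\,x_i$$
This formula is very familiar: it looks much like the one for linear regression. In linear regression $f(x) = w^T x$, whereas here $f(x) = \frac{1}{1+e^{-w^T x}}$. In both cases the gradient is
(estimate - true value) × input, summed over all samples.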
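The intermediate steps can be written out as follows. This sketch assumes the loss is the negative log-likelihood $J(w) = -\sum_i [y_i \ln f_i + (1-y_i)\ln(1-f_i)]$. The shorthand $f_i = f(z_i)$ and $z_i = w^T x_i$ is introduced here for brevity.

```latex
\begin{aligned}
f'(z) &= \frac{e^{-z}}{(1+e^{-z})^{2}}
       = \frac{1}{1+e^{-z}}\cdot\frac{e^{-z}}{1+e^{-z}}
       = f(z)\,\bigl(1-f(z)\bigr), \qquad
\frac{\partial z_i}{\partial w} = x_i \\[4pt]
\frac{\partial J}{\partial w}
  &= -\sum_{i=1}^{n}\left[\frac{y_i}{f_i} - \frac{1-y_i}{1-f_i}\right] f_i\,(1-f_i)\,x_i \\
  &= -\sum_{i=1}^{n}\bigl[y_i\,(1-f_i) - (1-y_i)\,f_i\bigr]\,x_i \\
  &= \sum_{i=1}^{n}\,(f_i - y_i)\,x_i
\end{aligned}
```

The factor $f_i(1-f_i)$ from the sigmoid derivative cancels the denominators, which is why the final expression is so simple.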
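The update rule is:
$$w := w - \alpha \sum_{i=1}^{n}\bigl(f(x_i) - y_i\bigr)\,x_i$$
Here $\alpha$ is the learning rate, and $f - y$ is the difference between the estimate and the true value. The larger this difference, the more $w$ needs to be adjusted: the adjustment is proportional to it. Likewise, the larger $x$ is, the larger the adjustment.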
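One way to gain confidence in the gradient used by this update is to compare it against a central finite-difference approximation of the negative log-likelihood. The sketch below does that on a tiny made-up dataset. The handles `nll` and `grad`, the step `h`, the test point `w`, and the data are all illustrative assumptions.

```matlab
sigmoid = @(z) 1 ./ (1 + exp(-z));

% Negative log-likelihood and its analytic gradient (row vector w,
% augmented sample matrix X with one sample per row, column labels y)
nll  = @(w, X, y) -sum(y' .* log(sigmoid(w * X')) + ...
                       (1 - y') .* log(1 - sigmoid(w * X')));
grad = @(w, X, y) (sigmoid(w * X') - y') * X;

% Made-up data and an arbitrary point at which to test the gradient
X = [1 1 2; 1 2 1; 1 6 5; 1 7 7];
y = [0; 0; 1; 1];
w = [0.1 -0.2 0.3];

% Central finite differences, one coordinate of w at a time
h = 1e-6;
numgrad = zeros(size(w));
for j = 1:numel(w)
    step = zeros(size(w));
    step(j) = h;
    numgrad(j) = (nll(w + step, X, y) - nll(w - step, X, y)) / (2 * h);
end

gap = max(abs(numgrad - grad(w, X, y)));
fprintf('largest difference: %g\n', gap);
```

If the analytic gradient is right, the printed difference should be tiny compared with the gradient entries themselves.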
Python implementation
http://blog.csdn.net/taiji1985/article/details/51250860
MATLAB implementation
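function [w,f,c,accury] = lr_predict(x,y)
% Train a logistic regression classifier with batch gradient descent.
% x: n-by-k sample matrix, y: n-by-1 vector of 0/1 labels
[n,k] = size(x);            % n = number of samples, k = number of features
% Augment x with a leading column of ones (intercept term)
x = [ones(n,1) x];
% Random initial weights; a row vector makes w*x' convenient
w = rand(1,k+1);
olde = 0;                   % error of the previous iteration
e = 1;
tol = 0.0001;               % stop when the error changes by less than this
rate = 0.01;                % learning rate
i = 0;                      % iteration counter
while true
    z = w*x';
    f = 1./(1+exp(-z));             % predicted probabilities
    e = sum(abs(f-y'))/n;           % mean absolute error
    w = w - rate*(f-y')*x;          % update the weights
    d = abs(olde - e);              % change in error between iterations
    fprintf('%d iter e = %f , d = %f \n',i,e,d);
    if d < tol
        break;
    end
    olde = e;
    i = i+1;
end
c = f>0.5;                              % predicted class labels
accury = (n - sum(abs(c-y')))/n;        % accuracy
fprintf('accury is %f ',accury);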
% Create test data (script that calls lr_predict)
seed = 333;
rng(seed); % seed the global generator so normrnd is reproducible
x1 = [normrnd(3,1,40,1) normrnd(3,2,40,1)];
x2 = [normrnd(7,1,40,1) normrnd(6,2,40,1)];
fig_on = 1;
if fig_on
hold off;
plot(x1(:,1),x1(:,2),'or');
hold on;
plot(x2(:,1),x2(:,2),'ob');
end
y1 = zeros(40,1);
y2 = ones(40,1);
x = [x1;x2];
y = [y1;y2];
if fig_on
hold off;
figure(2);
plot3(x1(:,1),x1(:,2),y1,'or');
hold on;
plot3(x2(:,1),x2(:,2),y2,'ob');
end
%surf(x(:,1),x(:,2),y);
% Run the classifier
[w,f,c,a] = lr_predict(x,y)
if fig_on
hold off;
figure(3);
plot(x1(:,1),x1(:,2),'or');
hold on;
plot(x2(:,1),x2(:,2),'ob');
xe = x(find(c ~= y'),:)
plot(xe(:,1),xe(:,2),'sm','MarkerSize',10,'LineWidth',2);
%mm = min(x);
%mx = max(x);
%xx = mm(1):0.1:mx(1);
%yy = -(w(1)+w(2)*xx)/w(3);
%plot(xx,yy);
end
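The commented-out lines at the end of the script hint at drawing the decision boundary. A possible version is sketched below, assuming the boundary is the set of points where $w_1 + w_2 x_1 + w_3 x_2 = 0$ (that is, where $f = 0.5$), so that $x_2 = -(w_1 + w_2 x_1)/w_3$; note the minus sign. The weight values are the ones printed in this article's MATLAB output, and the plotting range is an arbitrary choice.

```matlab
% Weights printed in the article's output: [intercept, weight for x1, weight for x2]
w = [-10.0562 1.5878 0.4737];

% Boundary: w(1) + w(2)*x1 + w(3)*x2 = 0  =>  x2 = -(w(1) + w(2)*x1) / w(3)
xx = 4:0.1:7;                        % plotting range for x1 (arbitrary choice)
yy = -(w(1) + w(2) * xx) / w(3);

% Every boundary point should give z = 0, i.e. f = 0.5
z = w * [ones(size(xx)); xx; yy];
fprintf('max |z| on the boundary: %g\n', max(abs(z)));

% Overlay the line on the current scatter plot
hold on;
plot(xx, yy, 'k-');
hold off;
```

Running this right after figure 3 is drawn overlays the line on the scatter plot of the two classes.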
Experimental Results
Misclassified points are marked with magenta squares.
The MATLAB output is:
0 iter e = 0.476728 , d = 0.476728
1 iter e = 0.527496 , d = 0.050768
2 iter e = 0.499542 , d = 0.027954
3 iter e = 0.485917 , d = 0.013625
4 iter e = 0.543169 , d = 0.057251
5 iter e = 0.499576 , d = 0.043593
6 iter e = 0.488516 , d = 0.011060
7 iter e = 0.553108 , d = 0.064592
8 iter e = 0.499505 , d = 0.053603
9 iter e = 0.487229 , d = 0.012276
10 iter e = 0.544575 , d = 0.057346
11 iter e = 0.499370 , d = 0.045205
12 iter e = 0.484179 , d = 0.015191
13 iter e = 0.532477 , d = 0.048298
14 iter e = 0.499217 , d = 0.033260
15 iter e = 0.481059 , d = 0.018158
16 iter e = 0.520977 , d = 0.039918
17 iter e = 0.498982 , d = 0.021995
18 iter e = 0.476170 , d = 0.022812
19 iter e = 0.510895 , d = 0.034724
20 iter e = 0.498714 , d = 0.012181
21 iter e = 0.471163 , d = 0.027551
22 iter e = 0.499601 , d = 0.028438
23 iter e = 0.498290 , d = 0.001311
24 iter e = 0.463343 , d = 0.034947
25 iter e = 0.492636 , d = 0.029293
26 iter e = 0.497855 , d = 0.005219
27 iter e = 0.456635 , d = 0.041220
28 iter e = 0.479775 , d = 0.023140
29 iter e = 0.497030 , d = 0.017256
30 iter e = 0.443815 , d = 0.053215
31 iter e = 0.478942 , d = 0.035127
32 iter e = 0.496471 , d = 0.017529
33 iter e = 0.438143 , d = 0.058328
34 iter e = 0.456484 , d = 0.018341
35 iter e = 0.494499 , d = 0.038015
36 iter e = 0.412978 , d = 0.081521
37 iter e = 0.473886 , d = 0.060908
38 iter e = 0.494617 , d = 0.020731
39 iter e = 0.420671 , d = 0.073946
40 iter e = 0.412639 , d = 0.008031
41 iter e = 0.487323 , d = 0.074684
42 iter e = 0.345426 , d = 0.141897
43 iter e = 0.482066 , d = 0.136640
44 iter e = 0.492669 , d = 0.010603
45 iter e = 0.409059 , d = 0.083610
46 iter e = 0.329026 , d = 0.080033
47 iter e = 0.459020 , d = 0.129994
48 iter e = 0.160198 , d = 0.298822
49 iter e = 0.323109 , d = 0.162911
50 iter e = 0.453148 , d = 0.130038
51 iter e = 0.148804 , d = 0.304344
52 iter e = 0.277777 , d = 0.128973
53 iter e = 0.423044 , d = 0.145267
54 iter e = 0.133613 , d = 0.289431
55 iter e = 0.173894 , d = 0.040281
56 iter e = 0.347993 , d = 0.174099
57 iter e = 0.457374 , d = 0.109381
58 iter e = 0.199874 , d = 0.257501
59 iter e = 0.378111 , d = 0.178237
60 iter e = 0.464773 , d = 0.086662
61 iter e = 0.258071 , d = 0.206702
62 iter e = 0.402756 , d = 0.144686
63 iter e = 0.468620 , d = 0.065864
64 iter e = 0.292597 , d = 0.176024
65 iter e = 0.368138 , d = 0.075541
66 iter e = 0.453801 , d = 0.085663
67 iter e = 0.229217 , d = 0.224584
68 iter e = 0.336790 , d = 0.107573
69 iter e = 0.436250 , d = 0.099459
70 iter e = 0.163286 , d = 0.272964
71 iter e = 0.232334 , d = 0.069048
72 iter e = 0.356480 , d = 0.124146
73 iter e = 0.150146 , d = 0.206334
74 iter e = 0.236759 , d = 0.086612
75 iter e = 0.299827 , d = 0.063068
76 iter e = 0.407203 , d = 0.107377
77 iter e = 0.099765 , d = 0.307438
78 iter e = 0.109024 , d = 0.009258
79 iter e = 0.137674 , d = 0.028651
80 iter e = 0.166169 , d = 0.028494
81 iter e = 0.263474 , d = 0.097305
82 iter e = 0.267471 , d = 0.003997
83 iter e = 0.377094 , d = 0.109624
84 iter e = 0.082407 , d = 0.294687
85 iter e = 0.082223 , d = 0.000184
86 iter e = 0.082569 , d = 0.000347
87 iter e = 0.082813 , d = 0.000244
88 iter e = 0.083618 , d = 0.000805
89 iter e = 0.084850 , d = 0.001232
90 iter e = 0.086432 , d = 0.001582
91 iter e = 0.090219 , d = 0.003787
92 iter e = 0.093099 , d = 0.002880
93 iter e = 0.103860 , d = 0.010761
94 iter e = 0.109086 , d = 0.005226
95 iter e = 0.139819 , d = 0.030733
96 iter e = 0.151821 , d = 0.012001
97 iter e = 0.231801 , d = 0.079980
98 iter e = 0.230153 , d = 0.001648
99 iter e = 0.335810 , d = 0.105658
100 iter e = 0.108298 , d = 0.227512
101 iter e = 0.136344 , d = 0.028046
102 iter e = 0.135687 , d = 0.000657
103 iter e = 0.196323 , d = 0.060637
104 iter e = 0.191395 , d = 0.004928
105 iter e = 0.288650 , d = 0.097254
106 iter e = 0.168363 , d = 0.120287
107 iter e = 0.253856 , d = 0.085494
108 iter e = 0.190190 , d = 0.063667
109 iter e = 0.283906 , d = 0.093716
110 iter e = 0.158495 , d = 0.125411
111 iter e = 0.234220 , d = 0.075725
112 iter e = 0.178957 , d = 0.055263
113 iter e = 0.265306 , d = 0.086349
114 iter e = 0.161863 , d = 0.103443
115 iter e = 0.236991 , d = 0.075128
116 iter e = 0.166209 , d = 0.070782
117 iter e = 0.242725 , d = 0.076516
118 iter e = 0.159634 , d = 0.083091
119 iter e = 0.230343 , d = 0.070709
120 iter e = 0.155406 , d = 0.074936
121 iter e = 0.221664 , d = 0.066258
122 iter e = 0.149961 , d = 0.071703
123 iter e = 0.210649 , d = 0.060688
124 iter e = 0.143460 , d = 0.067189
125 iter e = 0.197445 , d = 0.053984
126 iter e = 0.135411 , d = 0.062034
127 iter e = 0.181018 , d = 0.045607
128 iter e = 0.125077 , d = 0.055941
129 iter e = 0.159931 , d = 0.034854
130 iter e = 0.111675 , d = 0.048256
131 iter e = 0.133207 , d = 0.021532
132 iter e = 0.095523 , d = 0.037684
133 iter e = 0.103705 , d = 0.008182
134 iter e = 0.080177 , d = 0.023529
135 iter e = 0.080751 , d = 0.000574
136 iter e = 0.070453 , d = 0.010298
137 iter e = 0.069785 , d = 0.000668
138 iter e = 0.066274 , d = 0.003512
139 iter e = 0.066042 , d = 0.000232
140 iter e = 0.064806 , d = 0.001236
141 iter e = 0.064808 , d = 0.000002
accury is 0.975000
w =
-10.0562 1.5878 0.4737
f =
Columns 1 through 8
0.0014 0.0028 0.0314 0.0332 0.0451 0.0035 0.0238 0.0009
Columns 9 through 16
0.1501 0.3391 0.0304 0.0520 0.1480 0.0497 0.1703 0.1283
Columns 17 through 24
0.5347 0.0155 0.0051 0.0124 0.1204 0.1101 0.0007 0.1727
Columns 25 through 32
0.0213 0.0672 0.1867 0.0081 0.0015 0.0043 0.1033 0.0065
Columns 33 through 40
0.0043 0.0015 0.0646 0.1212 0.2088 0.2685 0.0237 0.0017
Columns 41 through 48
0.9905 0.9750 0.9831 0.9980 0.9959 0.8795 0.4102 0.9809
Columns 49 through 56
0.9971 0.9740 0.9981 0.9888 0.9348 0.9797 0.9496 0.9882
Columns 57 through 64
0.9712 0.9937 0.9818 0.9124 0.9939 0.9842 0.9926 0.8119
Columns 65 through 72
0.9730 0.9858 0.9983 0.9994 0.9216 0.9928 0.9969 0.9896
Columns 73 through 80
0.9889 0.6962 0.8998 0.9982 0.9972 0.9887 0.9989 0.9995
c =
Columns 1 through 14
0 0 0 0 0 0 0 0 0 0 0 0 0 0
Columns 15 through 28
0 0 1 0 0 0 0 0 0 0 0 0 0 0
Columns 29 through 42
0 0 0 0 0 0 0 0 0 0 0 0 1 1
Columns 43 through 56
1 1 1 1 0 1 1 1 1 1 1 1 1 1
Columns 57 through 70
1 1 1 1 1 1 1 1 1 1 1 1 1 1
Columns 71 through 80
1 1 1 1 1 1 1 1 1 1
a =
0.9750
xe =
6.0615 0.7585
5.1592 2.6908
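The error trace above oscillates for more than a hundred iterations before settling. One possible explanation, offered here as an assumption rather than something the original article states, is that the fixed step `rate = 0.01` is large relative to a gradient that is summed over all 80 samples. For the negative log-likelihood, the curvature is bounded by $\lambda_{\max}(X^T X)/4$, so a step of $4/\lambda_{\max}(X^T X)$ guarantees that this loss never increases. The sketch below illustrates the idea on a small made-up dataset. The names `safe_rate` and `nll` and the data are hypothetical.

```matlab
sigmoid = @(z) 1 ./ (1 + exp(-z));
nll = @(w, X, y) -sum(y' .* log(sigmoid(w * X')) + ...
                      (1 - y') .* log(1 - sigmoid(w * X')));

% Small made-up dataset, already augmented with a leading column of ones
X = [1 1 2; 1 2 1; 1 3 4; 1 6 5; 1 5 5; 1 7 7; 1 8 6];
y = [0; 0; 0; 0; 1; 1; 1];

% Step size 1/L, where L = lambda_max(X'X)/4 bounds the curvature of the loss
safe_rate = 4 / max(eig(X' * X));

w = zeros(1, 3);
iters = 200;
J = zeros(1, iters + 1);
J(1) = nll(w, X, y);
for t = 1:iters
    f = sigmoid(w * X');
    w = w - safe_rate * (f - y') * X;   % same update form as lr_predict
    J(t + 1) = nll(w, X, y);
end

fprintf('loss: %.4f -> %.4f, largest increase: %g\n', ...
        J(1), J(end), max(diff(J)));
```

Note that this monitors the negative log-likelihood, not the mean absolute error `e` printed by `lr_predict`, and a step this conservative may need many more iterations to reach a comparable accuracy.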