Machine Learning---Logistic迴歸

       本章節主要講解Logistic迴歸的原理及其數學推導,Logistic有3種不同的表達形式,現在我就一一展開這幾種不同的形式,以及它在分類中的效果。並比較這三種形式。 

下面分別寫出這三種形式的損失函數:


下面分別寫出這三種損失函數的梯度形式:


其中第一種形式和第三種形式是等價的,推導如下:



Steepest descent

   前面章節已經講過最速下降法的更新公式,如下:


下面將給出代碼這樣容易理解:

main.m

<span style="font-family:Times New Roman;">[D,b] = load_data();
%%% run exp and log convex logistic regression %%%
x0 = randn(3,1);    % initial point
alpha = 10^-2;        % step length
x = grad_descent_exp_logistic(D,b,x0,alpha);
% Run log convex logistic regression
alpha = 10^-1;        % step length
y = grad_descent_log_logistic(D,b,x0,alpha);
%%% plot everything, pts and lines %%%
plot_all(D',b,x,y);</span>


load_data().m

<span style="font-family:Times New Roman;"> function [A,b] = load_data()
        data = load('exp_vs_log_data.mat');
        data = data.data;
        A = data(:,1:3);
        A = A';
        b = data(:,4);
    end</span>

grad_descent_exp_logistic.m

<span style="font-family:Times New Roman;">function x = grad_descent_exp_logistic(D,b,x0,alpha)
        % Initializations
        x = x0;
        iter = 1;
        max_its = 3000;
        grad = 1;
        m=22;
        while  norm(grad) > 10^-6 && iter < max_its
            
            % compute gradient
              sum=0;
            for i=1:22
                z=b(i)*(D(:,i)'*x);
                tmp1=exp(-z);
                tmp2=-b(i)*D(:,i)';
                sum=sum+tmp1*tmp2';
            end
            grad=(1/22)*sum;         % your code goes here!
            x = x - alpha*grad;
            
            % update iteration count
            iter = iter + 1;
        end
    end</span>

grad_descent_log_logistic.m

<span style="font-family:Times New Roman;">function x = grad_descent_log_logistic(D,b,x0,alpha)
        % Initializations
        x = x0;
        iter = 1;
        max_its = 3000;
        grad = 1;
        m=22;
        while  norm(grad) > 10^-6 && iter < max_its
            sum=0;
            for i=1:22
                z=b(i)*(D(:,i)'*x);
                tmp1=exp(-z)/sigmoid(z);
                tmp2=-b(i)*D(:,i)';
                sum=sum+tmp1*tmp2';
            end
            grad=(1/22)*sum;
            x = x - alpha*grad;
            % update iteration count
            iter = iter + 1;
        end
    end</span>


plot_all.m

<span style="font-family:Times New Roman;">function plot_all(A,b,x,y)
        
        % plot points
        ind = find(b == 1);
        scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','b','markerFacecolor','none');
        hold on
        ind = find(b == -1);
        scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','r','markerFacecolor','none');
        hold on
        
        % plot separators
        s =[min(A(:,2)):.01:max(A(:,2))];
        plot (s,(-x(1)-x(2)*s)/x(3),'m','linewidth',2);
        hold on
        
        plot (s,(-y(1)-y(2)*s)/y(3),'k','linewidth',2);
        hold on
        
        set(gcf,'color','w');
        axis([ (min(A(:,2)) - 0.1) (max(A(:,2)) + 0.1) (min(A(:,3)) - 0.1) (max(A(:,3)) + 0.1)])
        box off
        
        % graph info labels
        xlabel('a_1','Fontsize',14)
        ylabel('a_2  ','Fontsize',14)
        set(get(gca,'YLabel'),'Rotation',0)
        
    end</span>

結果圖

其中黑線爲第二種損失函數,彩色線爲第一種損失函數。

                                                                                                                     資源----------------代碼和數據集見資源


                                                                                                                                    中科院大學雁西湖校區


發佈了59 篇原創文章 · 獲贊 60 · 訪問量 46萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章