The loss functions for the three forms are:
The corresponding gradients of the three loss functions are:
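For reference, the two losses that the code in this post minimizes can be written as follows. This is a sketch reconstructed from the code, not a restatement of all three forms: the notation ($a_i$ for a column of `D`, $b_i \in \{-1,+1\}$ for its label, $x$ for the weights, $m$ for the number of points) and the $1/m$ scaling follow the code, and reading the "log" variant as $\log(1+e^{-z})$ is an assumption.

```latex
\begin{align*}
g_{\exp}(x) &= \frac{1}{m}\sum_{i=1}^{m} e^{-b_i a_i^{T} x},
&
\nabla g_{\exp}(x) &= -\frac{1}{m}\sum_{i=1}^{m} e^{-b_i a_i^{T} x}\, b_i a_i \\
g_{\log}(x) &= \frac{1}{m}\sum_{i=1}^{m} \log\!\left(1 + e^{-b_i a_i^{T} x}\right),
&
\nabla g_{\log}(x) &= -\frac{1}{m}\sum_{i=1}^{m} \sigma\!\left(-b_i a_i^{T} x\right) b_i a_i,
\qquad \sigma(t) = \frac{1}{1+e^{-t}}
\end{align*}
```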
The first and third forms are equivalent, as the following derivation shows:
Steepest descent
The update formula for steepest descent was covered in an earlier chapter:
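Written out in the notation of the code below, the update with a fixed step length $\alpha$ (`alpha`) takes the following form. The stopping rule shown is the one the code uses (gradient norm at most $10^{-6}$, or 3000 iterations); the superscript $k$ for the iteration index is notation assumed here.

```latex
\begin{align*}
x^{k} &= x^{k-1} - \alpha\, \nabla g\!\left(x^{k-1}\right), \qquad k = 1, 2, \dots \\
\text{stop when}\quad &\left\lVert \nabla g\!\left(x^{k-1}\right) \right\rVert_2 \le 10^{-6}
\quad\text{or}\quad k \ge 3000
\end{align*}
```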
The code below makes this concrete:
main.m
[D,b] = load_data();

%%% run exponential-loss and log-loss logistic regression %%%
x0 = randn(3,1);   % random initial point, shared by both runs

% Run exponential-loss logistic regression
alpha = 10^-2;     % step length
x = grad_descent_exp_logistic(D,b,x0,alpha);

% Run log-loss logistic regression
alpha = 10^-1;     % step length
y = grad_descent_log_logistic(D,b,x0,alpha);

%%% plot everything, pts and lines %%%
plot_all(D',b,x,y);
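function [A,b] = load_data()
% Load the data set: columns 1-3 hold the input points, column 4 the labels
data = load('exp_vs_log_data.mat');
data = data.data;
A = data(:,1:3);
A = A';            % transpose so that each column of A is one point (3 x m)
b = data(:,4);     % labels (the plotting code expects +1 / -1)
end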
grad_descent_exp_logistic.m
function x = grad_descent_exp_logistic(D,b,x0,alpha)
% Gradient descent on the exponential loss; each column of D is one point
% Initializations
x = x0;
iter = 1;
max_its = 3000;
grad = 1;
m = size(D,2);   % number of points (22 in this data set)
while norm(grad) > 10^-6 && iter < max_its
    % compute gradient
    total = zeros(size(x));   % accumulator (avoids shadowing the built-in sum)
    for i = 1:m
        z = b(i)*(D(:,i)'*x);
        tmp1 = exp(-z);
        tmp2 = -b(i)*D(:,i);
        total = total + tmp1*tmp2;
    end
    grad = (1/m)*total;
    x = x - alpha*grad;
    % update iteration count
    iter = iter + 1;
end
end
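The loop above can also be written without the explicit `for`. The sketch below, run on made-up toy data rather than the post's `exp_vs_log_data.mat`, computes the same exponential-loss gradient in vectorized form and compares it against a central finite-difference estimate. The names `exp_loss` and `exp_grad` are illustrative and do not appear in the original code.

```matlab
% Toy data, made up for illustration (not the post's data set)
rng(0);                                   % fix the seed so the run is repeatable
m = 22;
D = [ones(1,m); randn(2,m)];              % columns are the points [1; a_1; a_2]
b = sign(randn(m,1));                     % labels in {-1,+1}
b(b == 0) = 1;                            % guard against an exact zero draw

% Exponential loss and its gradient, vectorized over all m points
exp_loss = @(x) mean(exp(-b .* (D' * x)));
exp_grad = @(x) -(D * (b .* exp(-b .* (D' * x)))) / m;

% Central finite-difference estimate of the gradient at a random point
x = 0.1 * randn(3,1);
h = 1e-6;
fd = zeros(3,1);
for k = 1:3
    e = zeros(3,1);
    e(k) = h;
    fd(k) = (exp_loss(x + e) - exp_loss(x - e)) / (2*h);
end

% The difference should be tiny if the analytic gradient is right
disp(norm(fd - exp_grad(x)));
```

A check like this is a cheap way to catch a sign or scaling mistake in a hand-derived gradient before running thousands of descent steps.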
grad_descent_log_logistic.m
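function x = grad_descent_log_logistic(D,b,x0,alpha)
% Gradient descent on the log loss log(1+exp(-z)); each column of D is one point
% Initializations
x = x0;
iter = 1;
max_its = 3000;
grad = 1;
m = size(D,2);   % number of points (22 in this data set)
while norm(grad) > 10^-6 && iter < max_its
    % compute gradient
    total = zeros(size(x));   % accumulator (avoids shadowing the built-in sum)
    for i = 1:m
        z = b(i)*(D(:,i)'*x);
        % sigma(-z) = exp(-z)/(1+exp(-z)), the derivative coefficient of log(1+exp(-z))
        tmp1 = 1/(1+exp(z));
        tmp2 = -b(i)*D(:,i);
        total = total + tmp1*tmp2;
    end
    grad = (1/m)*total;
    x = x - alpha*grad;
    % update iteration count
    iter = iter + 1;
end
end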
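Assuming the "log" loss is $\log(1+e^{-z})$ with $z = b_i a_i^T x$, its gradient weights each point by $\sigma(-z) = e^{-z}/(1+e^{-z})$. The sketch below, again on made-up toy data, evaluates that weight as `1/(1+exp(z))`, which stays finite for extreme margins, and uses it in a vectorized gradient. The names `sig_neg` and `log_grad` are illustrative and not from the original code.

```matlab
% Toy data, made up for illustration (not the post's data set)
rng(1);
m = 22;
D = [ones(1,m); randn(2,m)];     % columns are the points [1; a_1; a_2]
b = sign(randn(m,1));            % labels in {-1,+1}
b(b == 0) = 1;                   % guard against an exact zero draw

% sigma(-z), written so that it never evaluates Inf/Inf
sig_neg = @(z) 1 ./ (1 + exp(z));

% Vectorized log-loss gradient: -(1/m) * sum_i sigma(-z_i) * b_i * a_i
log_grad = @(x) -(D * (b .* sig_neg(b .* (D' * x)))) / m;

% Compare with the naive form at extreme margins
z = [-800; 0; 800];
stable = sig_neg(z);                   % 1, 0.5, 0
naive  = exp(-z) ./ (1 + exp(-z));     % first entry is NaN: exp(800) overflows to Inf
disp([stable naive]);

% One gradient evaluation at a random point
x = 0.1 * randn(3,1);
disp(log_grad(x)');
```

With well-scaled data and a small step length the naive form rarely overflows, so this mainly matters when a descent run diverges or the margins grow large.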
function plot_all(A,b,x,y)
% A is m x 3 (one point per row); x and y are the two learned weight vectors
% plot points
ind = find(b == 1);
scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','b','markerFacecolor','none');
hold on
ind = find(b == -1);
scatter(A(ind,2),A(ind,3),'Linewidth',2,'Markeredgecolor','r','markerFacecolor','none');
% plot separators: the line x(1) + x(2)*a_1 + x(3)*a_2 = 0, solved for a_2
s = min(A(:,2)):.01:max(A(:,2));
plot(s,(-x(1)-x(2)*s)/x(3),'m','linewidth',2);
plot(s,(-y(1)-y(2)*s)/y(3),'k','linewidth',2);
set(gcf,'color','w');
axis([ (min(A(:,2)) - 0.1) (max(A(:,2)) + 0.1) (min(A(:,3)) - 0.1) (max(A(:,3)) + 0.1)])
box off
% graph info labels
xlabel('a_1','Fontsize',14)
ylabel('a_2','Fontsize',14)
set(get(gca,'YLabel'),'Rotation',0)
end
Result plot
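As a numeric complement to the plot, the two separators can be compared by counting how many points each one misclassifies. The sketch below uses small hypothetical stand-ins for `D`, `b`, `x` and `y`; in practice these would be the values produced by `main.m`. The helper name `count_errors` is illustrative.

```matlab
% Hypothetical stand-ins for the D, b, x, y produced by main.m
D = [1   1   1   1;
     0.2 0.4 0.6 0.8;
     0.1 0.3 0.7 0.9];          % columns are the points [1; a_1; a_2]
b = [-1; -1; 1; 1];             % labels
x = [-1; 1; 1];                 % placeholder weights for the exponential-loss run
y = [-0.3; 1; 0];               % placeholder weights for the log-loss run

% A point is misclassified when sign(a_i' * w) disagrees with its label b_i
% (a point lying exactly on the separator also counts as an error)
count_errors = @(w) sum(sign(D' * w) ~= b);

fprintf('exp-loss separator: %d misclassified\n', count_errors(x));
fprintf('log-loss separator: %d misclassified\n', count_errors(y));
```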