SVM簡單實例-A simple implementation of SVM using Matlab

本文是一個用Matlab實現的簡單的SVM實例,僅供參考,如有不足之處,歡迎指正。


主程序如下。

% Simple SMO (sequential minimal optimization) SVM demo on 2-D data.
clear; close all; clc;

% Generate two noisy 'circles' of radius 10 and 20 (diameters 20 and 40),
% each consisting of 100 points. The inner one is labelled +1, the outer -1.
[X1, y1] = generateData(100, 2, 10, 1);
[X2, y2] = generateData(100, 2, 20, -1);
X = [X1; X2];
y = [y1; y2];
m = size(X,1);

% Plot the two classes: +1 in yellow, -1 in black.
xPos = X(y == 1,:);
xNeg = X(y == -1,:);

scatter(xPos(:,1),xPos(:,2),'y+');
hold on
scatter(xNeg(:,1),xNeg(:,2),'k+');
axis equal
hold off

% Initialization of the Lagrange multipliers and of the bias term.
alphas = zeros(m,1);
b = 0;

% Punishment (box-constraint) factor C.
C = 1;

% Update of the Lagrange multipliers.
% The loop stops once max_idle_passes consecutive passes changed no alpha.
% A pair is left untouched when |alpha_new - alpha| < delta.
% num_updated = number of alpha pairs updated during one pass.
delta = 1e-10;
max_idle_passes = 6;
count_not_updated = 0;
iter = 0;
while(count_not_updated < max_idle_passes)
    [alphas, b, num_updated] = update_alphas(alphas, b, X, y, C, delta);
    if(num_updated == 0)
        count_not_updated = count_not_updated + 1;
    else
        count_not_updated = 0;
    end
    iter = iter + 1;
    fprintf('iteration = %d\n',iter);
    fprintf('count_not_updated = %d\n', count_not_updated);
end

% w = sum_i alphas(i) * y(i) * X(i,:). The kernel defined in ker is
% quadratic, so w alone does not give the decision function; the test
% below therefore evaluates calculate_u rather than X * w' + b.
w = (y .* alphas)' * X;

% Evaluate the trained machine on fresh points and on the training set.
[res1, res2, res3, res4, res5] = test_svm(X, y, alphas, b);



下面兩個函數用於更新拉格朗日乘子。

function [alphas_new, b_new, updated] = update_alphas(alphas, b, X, y, C, delta)
% UPDATE_ALPHAS  Perform one SMO pass over all training points.
%
%   alphas, b : current Lagrange multipliers (m-by-1) and bias
%   X, y      : training points (one per row) and their labels (m-by-1)
%   C         : box-constraint factor
%   delta     : minimum change of a multiplier that counts as an update
%
%   alphas_new, b_new : multipliers and bias after the pass
%   updated           : number of multiplier pairs changed in this pass
%
% For every point that violates the KKT conditions (second multiplier), a
% partner is searched among the later violating points so that |E2 - E1|
% is maximal; the pair is then optimised by update_alpha_pair.

alphas_new = alphas;
b_new = b;
updated = 0;

m = size(X,1);

for index2 = 1:m
    u2 = calculate_u(alphas_new, b_new, X, y, X(index2,1:2));
    if is_kkt(index2, y, u2, alphas_new, C)
        continue;   % this multiplier already satisfies KKT
    end
    E2 = u2 - y(index2);

    % Reset the partner for every candidate; otherwise a stale index1 from
    % an earlier candidate (possibly equal to index2) would be reused.
    index1 = 0;
    diff_E = 0;
    for cand = (index2 + 1):m
        u1 = calculate_u(alphas_new, b_new, X, y, X(cand,1:2));
        if ~is_kkt(cand, y, u1, alphas_new, C)
            dif = abs(E2 - (u1 - y(cand)));
            if dif > diff_E
                diff_E = dif;
                index1 = cand;
            end
        end
    end

    % No violating partner found: fall back to a neighbouring point while
    % staying inside 1..m.
    if index1 == 0
        if index2 < m
            index1 = index2 + 1;
        elseif m > 1
            index1 = index2 - 1;
        else
            continue;   % a single training point cannot form a pair
        end
    end

    [alpha1, alpha2, b_new, bool] = update_alpha_pair(index1, index2, alphas_new, b_new, X, y, C, delta);
    if bool == 1
        updated = updated + 1;
        alphas_new(index1) = alpha1;
        alphas_new(index2) = alpha2;
    end
end

end

function [alpha1, alpha2, b, bool] = update_alpha_pair(index_a1,index_a2, alphas, b_, X, y, C, delta)
% UPDATE_ALPHA_PAIR  Jointly optimise the multipliers index_a1 and index_a2.
%
%   Returns the new pair (alpha1, alpha2), the new bias b and bool = 1 when
%   the pair was changed. bool = 0 means the step was rejected (identical
%   indices, non-positive eta, or a change of alpha2 smaller than delta);
%   in that case b equals b_ and alpha1 is the old value.

old1 = alphas(index_a1);
old2 = alphas(index_a2);

% Values returned when the step is rejected.
alpha1 = old1;
alpha2 = old2;
b = b_;
bool = 0;

if index_a1 == index_a2
    return;
end

% Feasible interval for alpha2 and curvature along the constraint line.
[L, H] = calculateLH(index_a1, index_a2, alphas, y ,C);
eta = calculateLs(index_a1, index_a2, X);

x1 = X(index_a1,1:2);
x2 = X(index_a2,1:2);
y1 = y(index_a1);
y2 = y(index_a2);

% Prediction errors of the current model on both points.
E1 = calculate_u(alphas, b_, X, y, x1) - y1;
E2 = calculate_u(alphas, b_, X, y, x2) - y2;

% Only a strictly positive eta yields a usable step.
if ~(eta > 0)
    return;
end

% Unconstrained optimum for alpha2, clipped to [L, H].
alpha2 = old2 + y2 * (E1 - E2) / eta;
if alpha2 > H
    alpha2 = H;
elseif alpha2 < L
    alpha2 = L;
end

% Negligible progress: report "not updated".
if abs(old2 - alpha2) < delta
    return;
end

% alpha1 moves so that y1*alpha1 + y2*alpha2 stays constant.
alpha1 = old1 + y1 * y2 * (old2 - alpha2);

% Bias candidates derived from each of the two points.
d1 = alpha1 - old1;
d2 = alpha2 - old2;
k12 = ker(x1,x2);
b1 = b_ - E1 - y1 * d1 * ker(x1,x1) - y2 * d2 * k12;
b2 = b_ - E2 - y1 * d1 * k12 - y2 * d2 * ker(x2,x2);
b = update_b(alpha1, alpha2, b1, b2, C);

bool = 1;

end

下面這個函數用於計算SVM的輸出函數 u(x) = Σ αᵢ yᵢ K(xᵢ, x) + b。

function u = calculate_u(alphas, b, X, y, x)
% CALCULATE_U  Output of the SVM for a single point x.
%
%   u = sum_i alphas(i) * y(i) * ker(X(i,1:2), x) + b
%
%   alphas, y : m-by-1 multipliers and labels of the training points
%   X         : training points, one per row (the first two columns are used)
%   x         : the point to evaluate, as a row vector

nTrain = size(X,1);

% Kernel value between every training point and x.
kvals = zeros(nTrain,1);
for row = 1:nTrain
    kvals(row) = ker(X(row,1:2), x);
end

weights = alphas .* y;
u = weights' * kvals + b;
end

下面這個函數用於計算拉格朗日乘子的上下邊界。

function [L, H] = calculateLH(index1, index2, alphas, y ,C)
% CALCULATELH  Lower and upper bound for the multiplier alphas(index2).
%
%   The bounds keep both multipliers of the pair inside [0, C] while they
%   move along the line fixed by the labels y(index1) and y(index2).

a1 = alphas(index1);
a2 = alphas(index2);

if y(index1) ~= y(index2)
    % Opposite labels: the difference a2 - a1 is preserved.
    L = max(0, a2 - a1);
    H = min(C, C + a2 - a1);
else
    % Equal labels: the sum a2 + a1 is preserved.
    L = max(0, a2 + a1 - C);
    H = min(C, a2 + a1);
end

end
        

下面這個函數利用核函數計算目標函數沿約束直線方向的二階導數 η = K(x1,x1) + K(x2,x2) - 2K(x1,x2)。

function Ls = calculateLs(index1, index2, X)
% CALCULATELS  Ls = K(a,a) + K(b,b) - 2*K(a,b) for rows index1 and index2 of X,
% where K is the kernel implemented by ker.

rowA = X(index1,:);
rowB = X(index2,:);

kAA = ker(rowA, rowA);
kBB = ker(rowB, rowB);
kAB = ker(rowA, rowB);

Ls = kAA + kBB - 2 * kAB;
end

下面這個函數更新參數b。

function b = update_b(alpha1, alpha2, b1, b2, C)
% UPDATE_B  Choose the new bias from the two candidates b1 and b2.
%
%   b1 is taken when alpha1 lies strictly inside (0, C); otherwise b2 is
%   taken when alpha2 lies strictly inside (0, C); otherwise the mean of
%   both candidates is returned.

if alpha1 > 0 && alpha1 < C
    b = b1;
elseif alpha2 > 0 && alpha2 < C
    b = b2;
else
    b = (b1 + b2)/2;
end
end

下面這個函數用於判斷拉格朗日乘子是否滿足KKT條件。

function bool = is_kkt(index, y, u, alphas, C, tol)
% IS_KKT  Check whether the multiplier alphas(index) satisfies the KKT conditions.
%
%   index  : index of the training point
%   y      : label vector
%   u      : SVM output for that training point
%   alphas : vector of Lagrange multipliers
%   C      : box-constraint factor
%   tol    : optional tolerance on the margin (default 0)
%
%   bool = 1 when the conditions hold, 0 when they are violated.
%
% The KKT conditions are
%   alpha == 0      =>  y*u >= 1
%   0 < alpha < C   =>  y*u == 1
%   alpha == C      =>  y*u <= 1
% so the multiplier is a violator only when
%   y*u < 1 - tol  and alpha < C   (alpha should still grow), or
%   y*u > 1 + tol  and alpha > 0   (alpha should still shrink).
% A point lying exactly on the margin (y*u == 1) is consistent with every
% alpha in [0, C] and is therefore never reported as a violator.

if nargin < 6
    tol = 0;
end

margin_err = y(index) * u - 1;
a = alphas(index);

violated = (margin_err < -tol && a < C) || (margin_err > tol && a > 0);
bool = double(~violated);

end



下面這個函數用於定義核函數。

function kernel = ker(x1, x2)
% KER  Kernel function: (x1 . x2 + 1)^2 for two row vectors x1 and x2.

innerProduct = x1 * x2';
kernel = (innerProduct + 1)^2;

end

下面這個函數用於測試。

function [res1, res2, res3, res4, res5] = test_svm(X, y, alphas, b)
% TEST_SVM  Evaluate the trained SVM on generated points and training rows.
%
%   res1 : outputs for 100 generated points of radius 5
%   res2 : outputs for training rows 1..100
%   res3 : outputs for 100 generated points of radius 15
%   res4 : outputs for training rows 101..200
%   res5 : outputs for 100 generated points of radius 40
%
% The outputs come from calculate_u (the kernel expansion); a plain
% X * w' + b evaluation is not used here.

nPts = 100;

[Xt1, ~] = generateData(nPts, 2, 5, 0);
[Xt2, ~] = generateData(nPts, 2, 15, 0);
[Xt3, ~] = generateData(nPts, 2, 40, 0);

% SVM output for one row vector.
predict = @(pt) calculate_u(alphas, b, X, y, pt);

res1 = zeros(nPts,1);
res2 = zeros(nPts,1);
res3 = zeros(nPts,1);
res4 = zeros(nPts,1);
res5 = zeros(nPts,1);

for k = 1:nPts
    res1(k) = predict(Xt1(k,:));
    res2(k) = predict(X(k,:));
    res3(k) = predict(Xt2(k,:));
    res4(k) = predict(X(nPts+k,:));
    res5(k) = predict(Xt3(k,:));
end

end


下面這個函數用於生成訓練和測試所用的數據。

function [data, labels] = generateData(m, n, r, label)
% GENERATEDATA  Draw m noisy points on a circle of radius r.
%
%   data   : m-by-n matrix; column 1 is uniform in [-r, r], column 2 is the
%            matching +/- circle ordinate plus a perturbation in [0, 1)
%   labels : m-by-1 vector filled with label

noiseScale = 1;     % amplitude of the perturbation added to the ordinate
data = zeros(m, n);
labels = label .* ones(m, 1);

for k = 1:m
    abscissa = 2*r*rand - r;
    side = round(rand)*2-1;         % upper (+1) or lower (-1) half circle
    noise = rand * noiseScale;
    data(k,1) = abscissa;
    data(k,2) = (sqrt(r^2 - abscissa^2)) * side + noise;
end

end

生成的數據大致如下圖所示。



參考資料:

Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines, John C. Platt.

支持向量機通俗導論——理解SVM 的三層境界,July · pluskid (http://blog.csdn.net/v_july_v/article/details/7624837)


發佈了16 篇原創文章 · 獲贊 11 · 訪問量 4萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章