one class SVM

背景:通常一類問題出現在需要對訓練樣本進行一定比例的篩選,或者已知的訓練樣本都是正樣本,而負樣本卻很少的情況。
這種情況下,往往需要訓練一個對於訓練樣本緊湊的分類邊界,就可以通過負樣本實驗。一個簡單的實際例子是:一個工廠對於產品的合格性進行檢查時,往往所知道是合格產品的參數,而不合格的產品的參數要麼空間比較大,要麼知道的很少。這種情況下就可以通過已知的合格產品參數來訓練一個一類分類器,得到一個緊湊的分類邊界,超出這個邊界就認爲是不合格產品。形式化的說明知乎上有個很好的例子:
http://www.zhihu.com/question/22365729
原理:
它求解的模型如下所示:
這裏寫圖片描述
其中的參數V特別注意,它的值在0-1之間,是一個比例值,它的含義是:你的訓練樣本中最後被分類爲負樣本的比例。比如你有100個訓練樣本,V設爲0.1,去學一個one class SVM,然後在學到的SVM上測試你之前用的訓練集,最後可以看到有10個左右的樣本的標籤爲-1,被分爲負樣本。下面還會通過實驗進行說明。模型的求解等其他理論細節請參考大牛Bernhard Scholkopf的文章:Support Vector Method for Novelty Detection
http://papers.nips.cc/paper/1723-support-vector-method-for-novelty-detection.pdf
實驗:
libsvm中有關於one class SVM的實現,算法就是根據上面Bernhard Scholkopf的文章。可以用下面代碼進行測試:

% Generate training data 
r = 20;
fprintf('Learning the function x^2 + y^2 < 1\n');
[XX,YY] = meshgrid(-1:1/r:1,-1:1/r:1);
X = [XX(:) YY(:)];
Y = 2*(sqrt(X(:,1).^2 + X(:,2).^2) < 1)-1 ;

Pin = find(Y==1);
X1 = X(Pin,:);
Y1 = Y(Pin);

% Plot training data
figure;hold on;
set(gca,'FontSize',16);    
for i = 1:length(Y1)
    if(Y1(i) > 0)
        plot(X1(i,1),X1(i,2),'g.');
    else
        plot(X1(i,1),X1(i,2),'r.');
    end
end

% Draw the decision boundary
theta = linspace(0,2*pi,100);
plot(cos(theta),sin(theta),'k-');
xlabel('x');
ylabel('y')
title(sprintf('Training data (%i points)',length(Y))); 
axis equal tight; box on;

通過上面的代碼可以構造一個在圓內的訓練樣本,如下圖所有:
這裏寫圖片描述
這樣我們就假設所有圓外的樣本爲負樣本,下面我們就進行one-class SVM訓練,可以看到有個參數-n,它就是我們上面提到的V,也就是這些訓練樣本中包含負樣本的比例。

%%train one class SVM with 10% instances to be set as outlier
model = svmtrain(Y1,X1,'-s 2 -n 0.01');
[Y1,Y2,Y3] = svmpredict(Y,X,model);

for i = 1:length(Y)
    if(Y1(i) > 0)
        plot(X(i,1),X(i,2),'g.');
    else
        plot(X(i,1),X(i,2),'r.');
    end
end

上面的代碼可以得到下面的結果:
這裏寫圖片描述
其中預測爲正的用綠色表示,預測爲負的用紅色表示。
爲了說明參數V的作用,將-n設爲0.3後可以得到如下結果:
這裏寫圖片描述
(黑線圓圈裏是我們的訓練樣本的點)
上面是對負樣本的檢測效果驗證,當一個新樣本在圓內時,我們還是希望分類器可以將新樣本分爲正樣本,下面就通過一段代碼驗證了這個效果:

X21 = rand(100,1);
X22 = 1-X21;
X2 = [X21 X22];
Y2 = ones(100,1);
% X = [X;X2];
% Y = [Y;Y2];
[Y1,Y2,Y3] = svmpredict(Y2,X2,model);

for i = 1:length(Y1)
    if(Y1(i) > 0)
        plot(X2(i,1),X2(i,2),'g.');
    else
        plot(X2(i,1),X2(i,2),'r.');
    end
end

上面的代碼得到的效果如下圖:
這裏寫圖片描述
可以看到這些新樣本的標籤都分對了(n設爲0.01的情況),另外,由於新樣本(x1,y1)滿足x1+y1=1所以顯示在一條直線上。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章