背景:通常一類問題出現在需要對訓練樣本進行一定比例的篩選,或者已知的訓練樣本都是正樣本,而負樣本卻很少的情況。
這種情況下,往往需要訓練一個對於訓練樣本緊湊的分類邊界,就可以通過負樣本實驗。一個簡單的實際例子是:一個工廠對於產品的合格性進行檢查時,往往所知道是合格產品的參數,而不合格的產品的參數要麼空間比較大,要麼知道的很少。這種情況下就可以通過已知的合格產品參數來訓練一個一類分類器,得到一個緊湊的分類邊界,超出這個邊界就認爲是不合格產品。形式化的說明知乎上有個很好的例子:
http://www.zhihu.com/question/22365729
原理:
它求解的模型如下所示:
其中的參數V特別注意,它的值在0-1之間,是一個比例值,它的含義是:你的訓練樣本中最後被分類爲負樣本的比例。比如你有100個訓練樣本,V設爲0.1,去學一個one class SVM,然後在學到的SVM上測試你之前用的訓練集,最後可以看到有10個左右的樣本的標籤爲-1,被分爲負樣本。下面還會通過實驗進行說明。模型的求解等其他理論細節請參考大牛Bernhard Scholkopf的文章:Support Vector Method for Novelty Detection
http://papers.nips.cc/paper/1723-support-vector-method-for-novelty-detection.pdf
實驗:
libsvm中有關於one class SVM的實現,算法就是根據上面Bernhard Scholkopf的文章。可以用下面代碼進行測試:
% Generate training data
r = 20;
fprintf('Learning the function x^2 + y^2 < 1\n');
[XX,YY] = meshgrid(-1:1/r:1,-1:1/r:1);
X = [XX(:) YY(:)];
Y = 2*(sqrt(X(:,1).^2 + X(:,2).^2) < 1)-1 ;
Pin = find(Y==1);
X1 = X(Pin,:);
Y1 = Y(Pin);
% Plot training data
figure;hold on;
set(gca,'FontSize',16);
for i = 1:length(Y1)
if(Y1(i) > 0)
plot(X1(i,1),X1(i,2),'g.');
else
plot(X1(i,1),X1(i,2),'r.');
end
end
% Draw the decision boundary
theta = linspace(0,2*pi,100);
plot(cos(theta),sin(theta),'k-');
xlabel('x');
ylabel('y')
title(sprintf('Training data (%i points)',length(Y)));
axis equal tight; box on;
通過上面的代碼可以構造一個在圓內的訓練樣本,如下圖所有:
這樣我們就假設所有圓外的樣本爲負樣本,下面我們就進行one-class SVM訓練,可以看到有個參數-n,它就是我們上面提到的V,也就是這些訓練樣本中包含負樣本的比例。
%%train one class SVM with 10% instances to be set as outlier
model = svmtrain(Y1,X1,'-s 2 -n 0.01');
[Y1,Y2,Y3] = svmpredict(Y,X,model);
for i = 1:length(Y)
if(Y1(i) > 0)
plot(X(i,1),X(i,2),'g.');
else
plot(X(i,1),X(i,2),'r.');
end
end
上面的代碼可以得到下面的結果:
其中預測爲正的用綠色表示,預測爲負的用紅色表示。
爲了說明參數V的作用,將-n設爲0.3後可以得到如下結果:
(黑線圓圈裏是我們的訓練樣本的點)
上面是對負樣本的檢測效果驗證,當一個新樣本在圓內時,我們還是希望分類器可以將新樣本分爲正樣本,下面就通過一段代碼驗證了這個效果:
X21 = rand(100,1);
X22 = 1-X21;
X2 = [X21 X22];
Y2 = ones(100,1);
% X = [X;X2];
% Y = [Y;Y2];
[Y1,Y2,Y3] = svmpredict(Y2,X2,model);
for i = 1:length(Y1)
if(Y1(i) > 0)
plot(X2(i,1),X2(i,2),'g.');
else
plot(X2(i,1),X2(i,2),'r.');
end
end
上面的代碼得到的效果如下圖:
可以看到這些新樣本的標籤都分對了(n設爲0.01的情況),另外,由於新樣本(x1,y1)滿足x1+y1=1所以顯示在一條直線上。