struct跟蹤算法

博客:http://blog.csdn.net/qianxin_dh


     《Struck:Structured Output Tracking with Kernels》是 Sam Hare, Amir Saffari, Philip H. S. Torr等人於2011年發表在Computer Vision (ICCV)上的一篇文章。Struck與傳統跟蹤算法的不同之處在於:傳統跟蹤算法(下圖右手邊)將跟蹤問題轉化爲一個分類問題,並通過在線學習技術更新目標模型。然而,爲了達到更新的目的,通常需要將一些預估計的目標位置作爲已知類別的訓練樣本,這些分類樣本並不一定與實際目標一致,因此難以實現最佳的分類效果。而Struck算法(下圖左手邊)主要提出一種基於結構輸出預測的自適應視覺目標跟蹤的框架,通過明確引入輸出空間滿足跟蹤功能,能夠避免中間分類環節,直接輸出跟蹤結果。同時,爲了保證實時性,該算法還引入了閾值機制,防止跟蹤過程中支持向量的過增長。



      最後,大牛作者們也提供了c++版源代碼,有興趣的朋友可以下載下來,體驗下該算法的強大代碼下載地址:https://github.com/gnebehay/STRUCK,代碼調試需要Opencv2.1以上版本以及Eigen v2.0.15。


         在論文理解以及代碼調試的過程中,主要參考了以下資料,感到獲益匪淺,列舉如下:
        
         1) Eigen的安裝及學習
         
         2) 支持向量機通俗導論(理解SVM的三重境界)

               http://blog.csdn.net/v_july_v/article/details/7624837

第二部分

目前在目標檢測任務中,由於svm自身具有較好的推廣能力以及對分類的魯棒性,得到了越來越多的應用。 Struck算法便使用了在線結構輸出svm學習方法去解決跟蹤問題。不同於常規算法訓練一個分類器,Struck算法直接通過預測函數:,來預測每幀之間目標位置發生的變化,其中表示搜尋空間,例如,,上一幀中目標的新位置爲Pt-1,則在當前幀中,目標位置就爲(可見其實就是表示幀間目標位置變化關係的集合)。因此,在Struck算法中,已知類型的樣本用(x,y)表示,而不再是(x,+1)或者(x,-1)了。

        那麼y預測函數怎麼獲得呢?這就需要用到結構輸出svm方法了(svm基本概念學習可參考我上篇文章中給出的svm三重境界的鏈接),它在該算法中引入了一個判別函數,通過公式找到概率最大的對目標位置進行預測,也就是說,因爲我們還不知道當前幀的目標位置,那麼我們首先想到在上一幀中的目標能夠通過一些位置變化關係,出現在當前幀中的各處,但是呢,實際的目標只有一個,所以這些變換關係中也必然只有一個是最佳的,因此,我們需要找到這個最佳的,並通過,就可以成功找到目標啦~,至於搜尋空間如何取,在程序解讀時大家就會看到了。

       那麼如何找到呢?我個人理解是:將判別函數形式轉換爲:,其中,表示映射函數,是從輸入空間到某個特徵空間的映射,進而實現對樣本線性可分。因爲當分類平面(輸入空間中的超平面)離數據點的“間隔”越大,分類的確信度越大,所以需讓所選擇的分類平面最大化這個“間隔”值,這裏我們通過最小化凸目標函數來實現,該函數應滿足條件:,其中,表示兩個框之間的覆蓋率)。優化的目的是確保F(目標)>>F(非目標)。
        
       接下來問題又來了,如何獲得最小的w??文中採取的求解方式是利用拉格朗日對偶性,通過求解與原問題等價的對偶問題(dual problem),得到原始問題的最優解。通過給每一個約束條件加上一個拉格朗日乘子alpha,定義拉格朗日函數L(w,b,alpha)。一般對偶問題的求解過程如下:1)固定alpha,求L關於w,b的最小化。2)求L對alpha的極大。3)利用SMO算法求得拉格朗日乘子alpha。爲了簡化對偶問題求解,這裏定義了參數beta,可見論文中的Eq.(8)。

算法主要流程:

1.   首先讀入config.txt,初始化程序參數,這一過程主要由Config類實現;


2.   判斷是否使用攝像頭進行跟蹤,如使用攝像頭進行跟蹤,則initBB=(120,80,80,80);

      若使用視頻序列進行跟蹤,initBB由相應txt文件給出;


3.   將讀入的每幀圖像統一爲320*240。


4.   由當前第一幀以及框initBB,實現對跟蹤算法的初始化。


4.1   Initialise(frame,bb)

        由於我們之前獲取的initBB的座標定義爲float型,在這裏首先將其轉換爲int型。
        程序中選取haar特徵,gaussian核函數, 初始化參數m_needsIntegralImage=true,m_needsIntegralHist=false。因此在這裏,ImageRep image()主要實現了積分圖的計算(如果特徵爲histogram,則可實現積分直方圖的計算)。ImageRep類中的類成員包括frame,積分圖,積分直方圖。

4.2   UpdateLearner(image)

       該函數主要實現對預測函數的更新,首先通過RadialSamples()獲得5*16=80個樣本,再加上原始目標,總共含有81個樣本。之後判斷這81個樣本是否有超出圖像邊界的,超出的捨棄。將剩餘的樣本存入keptRects,其中,原始目標樣本存入keptRects[0]。定義一個多樣本類MultiSample,該類中的類成員主要包括樣本框以及ImageRep image。並通過Update(sample,0)來實現預測函數的更新。

4.3 Update(sample,0)

       該函數定義在LaRank類下,文章中參考文獻《Solving multiclass support vector machines with LaRank》提到了這種算法。當我們分析LaRank頭文件時,可看到struck算法重要步驟全部聚集在這個類中。該類中的類成員包括支持模式SupportPattern,支持向量SupportVector,Config類對象m_config,Features類對象m_features,Kernel類對象m_kernel,存放SupportPattern的m_sps,存放SupportVector的m_svs,用於顯示的m_debugImage,目標函數中的係數m_C,矩陣m_K。

       查看SupportPattern的定義,我們知道該結構主要包括x(存放特徵值),yv(,存放目標變化關係),images(存放圖片樣本),y(索引值,表明指定樣本存放位置),refCount(統計sv的個數??)。同樣,查看SupportVector的定義可知,該結構包括一個SupportPattern,y(索引值,表明指定樣本存放位置),b(beta),g(gradient),image(存放圖片樣本)。

       在函數Update(sample,0)中,定義了一個SupportPattern* sp。首先對於每個樣本框,其x,y座標分別減去原始目標框的x,y座標,將結果存入sp->yv。然後對於每個樣本框內的圖片統一尺寸爲30*30,並存入sp->images。對於每個樣本框,計算其haar特徵值,並存入sp->x。令sp->y=y=0,sp->refCount=0,最後將當前sp存入m_sps。


4.3.1  ProcessNew(int ind)

       之後執行ProcessNew(int ind),其中ind=m_sps.size()-1。由於每處理一幀圖像,m_sps的數量都增加1,這樣定義ind能夠保證ProcessNew所處理的樣本都是最新的樣本。ProcessNew處理之前,首先看函數AddSupportVector(SupportPattern* x,int y,double g)的定義:
         SupportVector* sv=new SupportVector;定義了一個支持向量。
       爲支持向量賦初值:sv->b=0.0,sv->x=x,sv->y=y,sv->g=g,並將該向量存入m_svs。接下來通過調用Kernel類中的Eval()函數更新核矩陣,即m_K,以後用於Algorithm 1計算。

現在再回到ProcessNew函數:
       第一個AddSupportVector(),將目標框作爲參數,增加一個支持向量存入m_svs,此時,m_svs.size()=1,m_K(0,0)=1.0,函數返回ip=0。
         之後執行MinGradiernt(int ind),求得公式10中的g最小值。返回最小梯度的數值以及對應的樣本框存放位置。
         第二個AddSupportVector(),將具有最小梯度的樣本框作爲參數,增加一個特徵向量存入m_svs,此時,m_svs.size()=2,並求得m_K(0,1),m_K(1,0),m_K(1,1)。函數返回in=1。
之後進行SMO算法進行計算,若某向量的beta值爲0,則捨棄該支持向量。


4.3.2  BudgetMaintenance()

       再之後執行函數BudgetMaintenance(),保證支持向量個數沒有超過100。


4.3.3  Reprocess()

       進行Reprocess()步驟,一個Reprocess()包括1個ProcessOld()和10個Optimize();

       ProcessOld()主要對已經存在的SupportPattern進行隨機選取並處理。和ProcessNew不同的地方是,這裏將滿足梯度最大以及滿足的支持向量作爲正支持向量。負支持向量依然根據梯度最小進行選取。之後再次執行SMO算法,判斷這些支持向量是否有效。

       Optimize()也是對已經存在的SupportPattern進行隨機選取並處理,但僅僅是對現有的支持向量的beta值進行調整,並不加入新的支持向量。正負支持向量的選取方式和ProcessOld()一樣。


4.3.4  BudgetMaintenance()

       執行函數BudgetMaintenance(),保證支持向量個數沒有超過100。

5.跟蹤模塊(Algorithm 2)

       首先通過ImageRep image()實現積分圖的計算,然後進行抽樣(這裏抽樣的結果和初始化時的抽樣結果不一樣,大概抽取幾千個樣本)。將超出圖像範圍的框捨棄,剩餘的保留在keptRects中。對keptRects中的每一個框,計算F函數,即,將結果保存在scores裏,並記錄值最大的那一個,將其作爲跟蹤結果。  UpdateDebugImage()函數主要實現程序運行時顯示的界面。UpdateLearner(image)同步驟4一致。

6.Debug()   顯示樣本圖像,綠色邊框的是正樣本,紅色邊框的負樣本。

第三部分 代碼

main.cpp
[cpp] view plain copy
  1. /*  
  2. * Struck: Structured Output Tracking with Kernels 
  3.  
  4. * Code to accompany the paper: 
  5. *   Struck: Structured Output Tracking with Kernels 
  6. *   Sam Hare, Amir Saffari, Philip H. S. Torr 
  7. *   International Conference on Computer Vision (ICCV), 2011 
  8.  
  9. * Copyright (C) 2011 Sam Hare, Oxford Brookes University, Oxford, UK 
  10.  
  11. * This file is part of Struck. 
  12.  
  13. * Struck is free software: you can redistribute it and/or modify 
  14. * it under the terms of the GNU General Public License as published by 
  15. * the Free Software Foundation, either version 3 of the License, or 
  16. * (at your option) any later version. 
  17.  
  18. * Struck is distributed in the hope that it will be useful, 
  19. * but WITHOUT ANY WARRANTY; without even the implied warranty of 
  20. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
  21. * GNU General Public License for more details. 
  22.  
  23. * You should have received a copy of the GNU General Public License 
  24. * along with Struck.  If not, see <http://www.gnu.org/licenses/>. 
  25.  
  26. */  
  27.   
  28. #include "Tracker.h"  
  29. #include "Config.h"  
  30.   
  31. #include <iostream>  
  32. #include <fstream>  
  33.   
  34. #include <opencv/cv.h>  
  35. #include <opencv/highgui.h>  
  36.   
  37. #include "vot.hpp"  
  38.   
  39. using namespace std;  
  40. using namespace cv;  
  41.   
  42. static const int kLiveBoxWidth = 80;         
  43. static const int kLiveBoxHeight = 80;        
  44.   
  45. void rectangle(Mat& rMat, const FloatRect& rRect, const Scalar& rColour)  
  46. {  
  47.     IntRect r(rRect);  
  48.     rectangle(rMat, Point(r.XMin(), r.YMin()), Point(r.XMax(), r.YMax()), rColour);  
  49. }  
  50.   
  51. int main(int argc, char* argv[])  
  52. {  
  53.     // 讀取文件對程序參數進行初始化  
  54.     string configPath = "config.txt";  
  55.     if (argc > 1)  
  56.     {  
  57.         configPath = argv[1];  
  58.     }  
  59.     Config conf(configPath);       //Config類主要讀取config.txt中的參數  
  60.   
  61.     if (conf.features.size() == 0)  
  62.     {  
  63.         cout << "error: no features specified in config" << endl;  
  64.         return EXIT_FAILURE;  
  65.     }  
  66.   
  67.     Tracker tracker(conf);     
  68.   
  69.     //Check if --challenge was passed as an argument  
  70.     bool challengeMode = false;  
  71.     for (int i = 1; i < argc; i++) {  
  72.         if (strcmp("--challenge", argv[i]) == 0) {        //判斷是否有挑戰模式(vot挑戰)  
  73.             challengeMode = true;  
  74.         }  
  75.     }  
  76.   
  77.     if (challengeMode) {    //VOT(Visual object tracking)挑戰,它提供了一個公共平臺,目標是比較各種跟蹤算法再短期跟蹤內的性能,討論視覺跟蹤領域的發展。  
  78.         //load region, images and prepare for output  
  79.         Mat frameOrig;  
  80.         Mat frame;  
  81.         VOT vot_io("region.txt""images.txt""output.txt");  
  82.         vot_io.getNextImage(frameOrig);  
  83.         resize(frameOrig, frame, Size(conf.frameWidth, conf.frameHeight));  
  84.         cv::Rect initPos = vot_io.getInitRectangle();  
  85.         vot_io.outputBoundingBox(initPos);  
  86.         float scaleW = (float)conf.frameWidth/frameOrig.cols;  
  87.         float scaleH = (float)conf.frameHeight/frameOrig.rows;  
  88.   
  89.         FloatRect initBB_vot = FloatRect(initPos.x*scaleW, initPos.y*scaleH, initPos.width*scaleW, initPos.height*scaleH);  
  90.         tracker.Initialise(frame, initBB_vot);  
  91.   
  92.         while (vot_io.getNextImage(frameOrig) == 1){  
  93.             resize(frameOrig, frame, Size(conf.frameWidth, conf.frameHeight));  
  94.   
  95.             tracker.Track(frame);  
  96.             const FloatRect& bb = tracker.GetBB();  
  97.             float x = bb.XMin()/scaleW;  
  98.             float y = bb.YMin()/scaleH;  
  99.             float w = bb.Width()/scaleW;  
  100.             float h = bb.Height()/scaleH;  
  101.   
  102.             cv::Rect output = cv::Rect(x,y,w,h);  
  103.   
  104.             vot_io.outputBoundingBox(output);  
  105.         }  
  106.   
  107.         return 0;  
  108.     }  
  109.   
  110.       
  111.     ofstream outFile;  
  112.     if (conf.resultsPath != "")  
  113.     {  
  114.         outFile.open(conf.resultsPath.c_str(), ios::out);    //將程序寫入resultpath  
  115.         if (!outFile)  
  116.         {  
  117.             cout << "error: could not open results file: " << conf.resultsPath << endl;  
  118.             return EXIT_FAILURE;  
  119.         }  
  120.     }  
  121.   
  122.     // if no sequence specified then use the camera  
  123.     bool useCamera = (conf.sequenceName == "");  
  124.   
  125.     VideoCapture cap;  
  126.   
  127.     int startFrame = -1;  
  128.     int endFrame = -1;  
  129.     FloatRect initBB;  
  130.     string imgFormat;  
  131.     float scaleW = 1.f;  
  132.     float scaleH = 1.f;  
  133.   
  134.     if (useCamera)  
  135.     {  
  136.         if (!cap.open(0))  
  137.         {  
  138.             cout << "error: could not start camera capture" << endl;  
  139.             return EXIT_FAILURE;  
  140.         }  
  141.         startFrame = 0;  
  142.         endFrame = INT_MAX;        /* maximum (signed) int value */  
  143.         Mat tmp;  
  144.         cap >> tmp;  
  145.         scaleW = (float)conf.frameWidth/tmp.cols;  
  146.         scaleH = (float)conf.frameHeight/tmp.rows;  
  147.   
  148.         initBB = IntRect(conf.frameWidth/2-kLiveBoxWidth/2, conf.frameHeight/2-kLiveBoxHeight/2, kLiveBoxWidth, kLiveBoxHeight);  
  149.         cout << "press 'i' to initialise tracker" << endl;  
  150.     }  
  151.     else  
  152.     {  
  153.         // parse frames file  
  154.         string framesFilePath = conf.sequenceBasePath+"/"+conf.sequenceName+"/"+conf.sequenceName+"_frames.txt";  //girl_frames.txt的文件路徑,該文件放在girl文件夾裏,內容爲0,501。  
  155.         ifstream framesFile(framesFilePath.c_str(), ios::in);     
  156.         if (!framesFile)  
  157.         {  
  158.             cout << "error: could not open sequence frames file: " << framesFilePath << endl;  
  159.             return EXIT_FAILURE;  
  160.         }  
  161.       
  162.         string framesLine;  
  163.         getline(framesFile, framesLine);  
  164.         sscanf(framesLine.c_str(), "%d,%d", &startFrame, &endFrame);   //startFrame=0;endFrame=501;  
  165.   
  166.         if (framesFile.fail() || startFrame == -1 || endFrame == -1)     
  167.         {  
  168.             cout << "error: could not parse sequence frames file" << endl;  
  169.             return EXIT_FAILURE;  
  170.         }  
  171.   
  172.         imgFormat = conf.sequenceBasePath+"/"+conf.sequenceName+"/imgs/img%05d.png";  
  173.   
  174.         // read first frame to get size  
  175.         char imgPath[256];  
  176.         sprintf(imgPath, imgFormat.c_str(), startFrame);  //sprintf把格式化的數據寫入某個字符串緩衝區(imgPath);  
  177.         Mat tmp = cv::imread(imgPath, 0);  
  178.         scaleW = (float)conf.frameWidth/tmp.cols;   //=1;  
  179.         scaleH = (float)conf.frameHeight/tmp.rows; //=1;  
  180.   
  181.         // read init box from ground truth file  
  182.         string gtFilePath = conf.sequenceBasePath+"/"+conf.sequenceName+"/"+conf.sequenceName+"_gt.txt";  //讀取girl_gt.txt文件  
  183.         ifstream gtFile(gtFilePath.c_str(), ios::in);  
  184.         if (!gtFile)  
  185.         {  
  186.             cout << "error: could not open sequence gt file: " << gtFilePath << endl;  
  187.             return EXIT_FAILURE;  
  188.         }  
  189.   
  190.       
  191.         string gtLine;  
  192.         getline(gtFile, gtLine);  
  193.         float xmin = -1.f;  
  194.         float ymin = -1.f;  
  195.         float width = -1.f;  
  196.         float height = -1.f;  
  197.         sscanf(gtLine.c_str(), "%f,%f,%f,%f", &xmin, &ymin, &width, &height);  //128,46,104,127  
  198.   
  199.         if (gtFile.fail() || xmin < 0.f || ymin < 0.f || width < 0.f || height < 0.f)  
  200.         {  
  201.             cout << "error: could not parse sequence gt file" << endl;  
  202.             return EXIT_FAILURE;  
  203.         }  
  204.         initBB = FloatRect(xmin*scaleW, ymin*scaleH, width*scaleW, height*scaleH);  
  205.     }  
  206.   
  207.     if (!conf.quietMode)  
  208.     {  
  209.         namedWindow("result");  
  210.     }  
  211.   
  212.     Mat result(conf.frameHeight, conf.frameWidth, CV_8UC3);  
  213.     bool paused = false;  
  214.     bool doInitialise = false;  
  215.     srand(conf.seed);  
  216.   
  217.     for (int frameInd = startFrame; frameInd <= endFrame; ++frameInd)    //逐幀處理  
  218.     {  
  219.         Mat frame;  
  220.         if (useCamera)   //若使用攝像頭  
  221.         {  
  222.             Mat frameOrig;  
  223.             cap >> frameOrig;  
  224.             resize(frameOrig, frame, Size(conf.frameWidth, conf.frameHeight));  
  225.             flip(frame, frame, 1);  
  226.             frame.copyTo(result);  
  227.             if (doInitialise)  
  228.             {  
  229.                 if (tracker.IsInitialised())  
  230.                 {  
  231.                     tracker.Reset();  
  232.                 }  
  233.                 else  
  234.                 {  
  235.                     tracker.Initialise(frame, initBB);  
  236.                 }  
  237.                 doInitialise = false;  
  238.             }  
  239.             else if (!tracker.IsInitialised())  
  240.             {  
  241.                 rectangle(result, initBB, CV_RGB(255, 255, 255));  
  242.             }  
  243.         }  
  244.         else    //若讀取圖片序列  
  245.         {             
  246.             char imgPath[256];  
  247.             sprintf(imgPath, imgFormat.c_str(), frameInd);  
  248.             Mat frameOrig = cv::imread(imgPath, 0);  
  249.             if (frameOrig.empty())  
  250.             {  
  251.                 cout << "error: could not read frame: " << imgPath << endl;  
  252.                 return EXIT_FAILURE;  
  253.             }  
  254.   
  255.             resize(frameOrig, frame, Size(conf.frameWidth, conf.frameHeight));  //將讀取的每幀圖像統一爲320*240;  
  256.             cvtColor(frame, result, CV_GRAY2RGB);  
  257.                           
  258.                                                     if (frameInd == startFrame)  
  259.             {  
  260.                 tracker.Initialise(frame, initBB);                 //對第一幀進行初始化  
  261.             }  
  262.         }  
  263.   
  264.         if (tracker.IsInitialised())   
  265.         {  
  266.             tracker.Track(frame);                     //開始跟蹤  
  267.   
  268.             if (!conf.quietMode && conf.debugMode)  
  269.             {  
  270.                 tracker.Debug();  //用於顯示樣本圖像  
  271.             }  
  272.   
  273.             rectangle(result, tracker.GetBB(), CV_RGB(0, 255, 0));  
  274.   
  275.             if (outFile)  
  276.             {  
  277.                 const FloatRect& bb = tracker.GetBB();  
  278.                 outFile << bb.XMin()/scaleW << "," << bb.YMin()/scaleH << "," << bb.Width()/scaleW << "," << bb.Height()/scaleH << endl;  
  279.             }   //輸出跟蹤結果座標  
  280.         }  
  281.   
  282.         if (!conf.quietMode)  
  283.         {  
  284.             imshow("result", result);   //顯示跟蹤畫面  
  285.             int key = waitKey(paused ? 0 : 1);  
  286.   
  287.             if (key != -1)  
  288.             {  
  289.                 if (key == 27 || key == 113) // esc q  
  290.                 {  
  291.                     break;  
  292.                 }  
  293.                 else if (key == 112) // p  
  294.                 {  
  295.                     paused = !paused;  
  296.                 }  
  297.                 else if (key == 105 && useCamera)  
  298.                 {  
  299.                     doInitialise = true;  
  300.                 }  
  301.             }  
  302.             if (conf.debugMode && frameInd == endFrame)  
  303.             {  
  304.                 cout << "\n\nend of sequence, press any key to exit" << endl;  
  305.                 waitKey();  
  306.             }  
  307.         }  
  308.     }  
  309.   
  310.     if (outFile.is_open())  
  311.     {  
  312.         outFile.close();  
  313.     }  
  314.   
  315.     return EXIT_SUCCESS;  
  316. }  


Tracker.cpp
[cpp] view plain copy
  1. #include "Tracker.h"  
  2. #include "Config.h"  
  3. #include "ImageRep.h"  
  4. #include "Sampler.h"  
  5. #include "Sample.h"  
  6. #include "GraphUtils/GraphUtils.h"  
  7.   
  8. #include "HaarFeatures.h"  
  9. #include "RawFeatures.h"  
  10. #include "HistogramFeatures.h"  
  11. #include "MultiFeatures.h"  
  12.   
  13. #include "Kernels.h"  
  14.   
  15. #include "LaRank.h"  
  16.   
  17. #include <opencv/cv.h>  
  18. #include <opencv/highgui.h>  
  19.   
  20. #include <Eigen/Core>  
  21.   
  22. #include <vector>  
  23. #include <algorithm>  
  24.   
  25. using namespace cv;  
  26. using namespace std;  
  27. using namespace Eigen;  
  28.   
  29. Tracker::Tracker(const Config& conf) :      //構造函數,對參數進行初始化  
  30.     m_config(conf),  
  31.     m_initialised(false),  
  32.     m_pLearner(0),  
  33.     m_debugImage(2*conf.searchRadius+1, 2*conf.searchRadius+1, CV_32FC1),  
  34.     m_needsIntegralImage(false)  
  35. {  
  36.     Reset();  
  37. }  
  38.   
  39. Tracker::~Tracker()  
  40. {  
  41.     delete m_pLearner;  
  42.     for (int i = 0; i < (int)m_features.size(); ++i)  
  43.     {  
  44.         delete m_features[i];  
  45.         delete m_kernels[i];  
  46.     }  
  47. }  
  48.   
  49. void Tracker::Reset()               //因爲初始化爲haar特徵核高斯核函數,所以m_needsIntegralImage = true,m_needsIntegralHist = false;  
  50. {  
  51.     m_initialised = false;  
  52.     m_debugImage.setTo(0);  
  53.     if (m_pLearner) delete m_pLearner;  
  54.     for (int i = 0; i < (int)m_features.size(); ++i)  
  55.     {  
  56.         delete m_features[i];  
  57.         delete m_kernels[i];  
  58.     }  
  59.     m_features.clear();  
  60.     m_kernels.clear();  
  61.       
  62.     m_needsIntegralImage = false;  
  63.     m_needsIntegralHist = false;  
  64.       
  65.     int numFeatures = m_config.features.size();  
  66.     vector<int> featureCounts;  
  67.     for (int i = 0; i < numFeatures; ++i)  
  68.     {  
  69.         switch (m_config.features[i].feature)  
  70.         {  
  71.         case Config::kFeatureTypeHaar:  
  72.             m_features.push_back(new HaarFeatures(m_config));  
  73.             m_needsIntegralImage = true;  
  74.             break;            
  75.         case Config::kFeatureTypeRaw:  
  76.             m_features.push_back(new RawFeatures(m_config));  
  77.             break;  
  78.         case Config::kFeatureTypeHistogram:  
  79.             m_features.push_back(new HistogramFeatures(m_config));  
  80.             m_needsIntegralHist = true;  
  81.             break;  
  82.         }  
  83.         featureCounts.push_back(m_features.back()->GetCount());  
  84.           
  85.         switch (m_config.features[i].kernel)  
  86.         {  
  87.         case Config::kKernelTypeLinear:  
  88.             m_kernels.push_back(new LinearKernel());  
  89.             break;  
  90.         case Config::kKernelTypeGaussian:  
  91.             m_kernels.push_back(new GaussianKernel(m_config.features[i].params[0]));  
  92.             break;  
  93.         case Config::kKernelTypeIntersection:  
  94.             m_kernels.push_back(new IntersectionKernel());  
  95.             break;  
  96.         case Config::kKernelTypeChi2:  
  97.             m_kernels.push_back(new Chi2Kernel());  
  98.             break;  
  99.         }  
  100.     }  
  101.       
  102.     if (numFeatures > 1)  
  103.     {  
  104.         MultiFeatures* f = new MultiFeatures(m_features);  
  105.         m_features.push_back(f);  
  106.           
  107.         MultiKernel* k = new MultiKernel(m_kernels, featureCounts);  
  108.         m_kernels.push_back(k);       
  109.     }  
  110.       
  111.     m_pLearner = new LaRank(m_config, *m_features.back(), *m_kernels.back());  
  112. }  
  113.       
  114.   
  115. void Tracker::Initialise(const cv::Mat& frame, FloatRect bb)  
  116. {  
  117.     m_bb = IntRect(bb);//將目標框座標轉爲int型  
  118.     //該類主要實現了積分圖計算  
  119.     ImageRep image(frame, m_needsIntegralImage, m_needsIntegralHist);  //後兩個參數分別爲true,false  
  120.     for (int i = 0; i < 1; ++i)  
  121.     {  
  122.         UpdateLearner(image);// 更新預測函數  
  123.     }  
  124.     m_initialised = true;  
  125. }  
  126.   
  127. void Tracker::Track(const cv::Mat& frame)  
  128. {  
  129.     assert(m_initialised);  
  130.       
  131.     ImageRep image(frame, m_needsIntegralImage, m_needsIntegralHist);   //獲得當前幀的積分圖  
  132.       
  133.     vector<FloatRect> rects = Sampler::PixelSamples(m_bb, m_config.searchRadius);  //抽樣  
  134.       
  135.     vector<FloatRect> keptRects;  
  136.     keptRects.reserve(rects.size());  
  137.     for (int i = 0; i < (int)rects.size(); ++i)  
  138.     {  
  139.         if (!rects[i].IsInside(image.GetRect())) continue;  
  140.         keptRects.push_back(rects[i]);        //將超出圖像範圍的框捨棄,剩餘的保留在keptRects中  
  141.     }  
  142.       
  143.     MultiSample sample(image, keptRects);     //多樣本類,主要包括樣本框以及ImageRep image  
  144.       
  145.     vector<double> scores;  
  146.     m_pLearner->Eval(sample, scores);     //scores裏存放的是論文中公式(10)後半部分  
  147.       
  148.     double bestScore = -DBL_MAX;  
  149.     int bestInd = -1;  
  150.     for (int i = 0; i < (int)keptRects.size(); ++i)  
  151.     {         
  152.         if (scores[i] > bestScore)  
  153.         {  
  154.             bestScore = scores[i];  
  155.             bestInd = i;              //找到bestScore  
  156.         }  
  157.     }  
  158.       
  159.     UpdateDebugImage(keptRects, m_bb, scores);//更新debug圖像,用於顯示  
  160.       
  161.     if (bestInd != -1)  
  162.     {  
  163.         m_bb = keptRects[bestInd];  
  164.         UpdateLearner(image);  
  165. #if VERBOSE       
  166.         cout << "track score: " << bestScore << endl;  
  167. #endif  
  168.     }  
  169. }  
  170.   
  171. void Tracker::UpdateDebugImage(const vector<FloatRect>& samples, const FloatRect& centre, const vector<double>& scores)  
  172. {  
  173.     double mn = VectorXd::Map(&scores[0], scores.size()).minCoeff();   //Map:將現存的結構映射到Eigen的數據結構裏,進行計算  
  174.     double mx = VectorXd::Map(&scores[0], scores.size()).maxCoeff();   //R.minCoeff()=min(R(:)), R.maxCoeff()=max(R(:));  
  175.     m_debugImage.setTo(0);     //置爲全黑色  
  176.     for (int i = 0; i < (int)samples.size(); ++i)  
  177.     {  
  178.         int x = (int)(samples[i].XMin() - centre.XMin());  
  179.         int y = (int)(samples[i].YMin() - centre.YMin());  
  180.         m_debugImage.at<float>(m_config.searchRadius+y,m_config.searchRadius+x)=(float)((scores[i]-mn)/(mx-mn));//scores得分越大的框,會在m_debugImage上具有越大的值,即該點越亮(類似於置信圖)  
  181.     }  
  182. }  
  183.   
  184. void Tracker::Debug()  
  185. {  
  186.     imshow("tracker", m_debugImage);   //顯示m_debugImage圖像  
  187.     m_pLearner->Debug();  
  188. }  
  189.   
  190. void Tracker::UpdateLearner(const ImageRep& image)     //更新預測函數  
  191. {  
  192.     // note these return the centre sample at index 0  
  193.     vector<FloatRect> rects = Sampler::RadialSamples(m_bb, 2*m_config.searchRadius, 5, 16);//5*16=80,加上一個原始rect,共包含81個rect  
  194.     //vector<FloatRect> rects = Sampler::PixelSamples(m_bb, 2*m_config.searchRadius, true);  
  195.       
  196.     vector<FloatRect> keptRects;  
  197.     keptRects.push_back(rects[0]); // 原始目標框  
  198.     for (int i = 1; i < (int)rects.size(); ++i)  
  199.     {  
  200.         if (!rects[i].IsInside(image.GetRect())) continue;   //判斷生成的樣本框是否超出圖像範圍,超出的捨棄  
  201.         keptRects.push_back(rects[i]);  
  202.     }  
  203.           
  204. #if VERBOSE       
  205.     cout << keptRects.size() << " samples" << endl;  
  206. #endif  
  207.           
  208.     MultiSample sample(image, keptRects);      //多樣本類對象sample,包含ImageRep& image,以及保留下來樣本框  
  209.     m_pLearner->Update(sample, 0);       //更新,在LaRank類下實現  
  210. }  

LaRank.h

[cpp] view plain copy
  1. #ifndef LARANK_H  
  2. #define LARANK_H  
  3.   
  4. #include "Rect.h"  
  5. #include "Sample.h"  
  6.   
  7. #include <vector>  
  8. #include <Eigen/Core>  
  9.   
  10. #include <opencv/cv.h>  
  11.   
  12. class Config;  
  13. class Features;  
  14. class Kernel;  
  15.   
  16. class LaRank   //文獻《Solving multiclass support vector machine with LaRank》,該類實現了struck算法的主要步驟  
  17. {  
  18. public:  
  19.     LaRank(const Config& conf, const Features& features, const Kernel& kernel);  //初始化參數,特徵值,核  
  20.     ~LaRank();  
  21.       
  22.     virtual void Eval(const MultiSample& x, std::vector<double>& results);  
  23.     virtual void Update(const MultiSample& x, int y);  
  24.       
  25.     virtual void Debug();  
  26.   
  27. private:  
  28.   
  29.     struct SupportPattern  
  30.     {  
  31.         std::vector<Eigen::VectorXd> x;   //特徵值  
  32.         std::vector<FloatRect> yv;        //變化關係  
  33.         std::vector<cv::Mat> images;      //圖像片  
  34.         int y;                            //索引值  
  35.         int refCount;                    //統計sp的個數?  
  36.     };  
  37.   
  38.     struct SupportVector  
  39.     {  
  40.         SupportPattern* x;  
  41.         int y;  
  42.         double b;                //beta  
  43.         double g;                 //gradient  
  44.         cv::Mat image;  
  45.     };  
  46.       
  47.     const Config& m_config;  
  48.     const Features& m_features;  
  49.     const Kernel& m_kernel;  
  50.       
  51.     std::vector<SupportPattern*> m_sps;  
  52.     std::vector<SupportVector*> m_svs;  
  53.   
  54.     cv::Mat m_debugImage;  
  55.       
  56.     double m_C;  
  57.     Eigen::MatrixXd m_K;  
  58.   
  59.     inline double Loss(const FloatRect& y1, const FloatRect& y2) const         //損失函數  
  60.     {  
  61.         // overlap loss  
  62.         return 1.0-y1.Overlap(y2);  
  63.         // squared distance loss  
  64.         //double dx = y1.XMin()-y2.XMin();  
  65.         //double dy = y1.YMin()-y2.YMin();  
  66.         //return dx*dx+dy*dy;  
  67.     }  
  68.       
  69.     double ComputeDual() const;  
  70.   
  71.     void SMOStep(int ipos, int ineg);  
  72.     std::pair<intdouble> MinGradient(int ind);  
  73.     void ProcessNew(int ind);  
  74.     void Reprocess();  
  75.     void ProcessOld();  
  76.     void Optimize();  
  77.   
  78.     int AddSupportVector(SupportPattern* x, int y, double g);  
  79.     void RemoveSupportVector(int ind);  
  80.     void RemoveSupportVectors(int ind1, int ind2);  
  81.     void SwapSupportVectors(int ind1, int ind2);  
  82.       
  83.     void BudgetMaintenance();  
  84.     void BudgetMaintenanceRemove();  
  85.   
  86.     double Evaluate(const Eigen::VectorXd& x, const FloatRect& y) const;  
  87.     void UpdateDebugImage();  
  88. };  
  89.   
  90. #endif  

LaRank.cpp

[cpp] view plain copy
  1. #include "LaRank.h"  
  2.   
  3. #include "Config.h"  
  4. #include "Features.h"  
  5. #include "Kernels.h"  
  6. #include "Sample.h"  
  7. #include "Rect.h"  
  8. #include "GraphUtils/GraphUtils.h"  
  9.   
  10. #include <Eigen/Array>  
  11.   
  12. #include <opencv/highgui.h>  
  13. static const int kTileSize = 30;  
  14. using namespace cv;  
  15.   
  16. using namespace std;  
  17. using namespace Eigen;  
  18.   
  19. static const int kMaxSVs = 2000; // TODO (only used when no budget)  
  20.   
  21.   
  22. LaRank::LaRank(const Config& conf, const Features& features, const Kernel& kernel) :  
  23.     m_config(conf),  
  24.     m_features(features),  
  25.     m_kernel(kernel),  
  26.     m_C(conf.svmC)  
  27. {  
  28.     int N = conf.svmBudgetSize > 0 ? conf.svmBudgetSize+2 : kMaxSVs;     //N=100+2,特徵向量的個數不能超過這個閾值  
  29.     m_K = MatrixXd::Zero(N, N);            //m_K表示核矩陣,102*102  
  30.     m_debugImage = Mat(800, 600, CV_8UC3);  
  31. }  
  32.   
  33. LaRank::~LaRank()  
  34. {  
  35. }  
  36.   
  37. double LaRank::Evaluate(const Eigen::VectorXd& x, const FloatRect& y) const  //論文中公式10後半部分計算,即F  
  38. {  
  39.     double f = 0.0;  
  40.     for (int i = 0; i < (int)m_svs.size(); ++i)  
  41.     {  
  42.         const SupportVector& sv = *m_svs[i];  
  43.         f += sv.b*m_kernel.Eval(x, sv.x->x[sv.y]);       //beta*高斯核  
  44.     }  
  45.     return f;  
  46. }  
  47.   
  48. void LaRank::Eval(const MultiSample& sample, std::vector<double>& results)  
  49. {  
  50.     const FloatRect& centre(sample.GetRects()[0]);       //原始目標框  
  51.     vector<VectorXd> fvs;  
  52.     const_cast<Features&>(m_features).Eval(sample, fvs);     //fvs存放haar特徵值  
  53.     results.resize(fvs.size());  
  54.     for (int i = 0; i < (int)fvs.size(); ++i)  
  55.     {  
  56.         // express y in coord frame of centre sample  
  57.         FloatRect y(sample.GetRects()[i]);  
  58.         y.Translate(-centre.XMin(), -centre.YMin());     //將每個框的橫縱座標分別減去原始目標框的橫縱座標  
  59.         results[i] = Evaluate(fvs[i], y);         //計算每個框的F函數,結果保存在results中。  
  60.     }  
  61. }  
  62.   
  63. void LaRank::Update(const MultiSample& sample, int y)  
  64. {  
  65.     // add new support pattern  
  66.     SupportPattern* sp = new SupportPattern;        //定義一個sp  
  67.     const vector<FloatRect>& rects = sample.GetRects();      //獲得所有的樣本框  
  68.     FloatRect centre = rects[y];                     //原始目標框  
  69.     for (int i = 0; i < (int)rects.size(); ++i)  
  70.     {  
  71.         // express r in coord frame of centre sample  
  72.         FloatRect r = rects[i];  
  73.         r.Translate(-centre.XMin(), -centre.YMin());   //這就表示幀間目標位置變化關係  
  74.         sp->yv.push_back(r);  
  75.         if (!m_config.quietMode && m_config.debugMode)  
  76.         {  
  77.             // store a thumbnail for each sample  
  78.             Mat im(kTileSize, kTileSize, CV_8UC1);  
  79.             IntRect rect = rects[i];  
  80.             cv::Rect roi(rect.XMin(), rect.YMin(), rect.Width(), rect.Height());  //感興趣的區域是那些抽取的樣本區域  
  81.             cv::resize(sample.GetImage().GetImage(0)(roi), im, im.size());       //0表示通道數,將感興趣區域統一爲30*30,並保存在sp裏的images  
  82.             sp->images.push_back(im);  
  83.         }  
  84.     }  
  85.     // evaluate features for each sample  
  86.     sp->x.resize(rects.size());    //有多少個感興趣的框,就有多少個特徵值向量。  
  87.     const_cast<Features&>(m_features).Eval(sample, sp->x);    //將每個樣本框計算得到的haar特徵存入sp->x,這裏關於haar特徵的代碼不再列出,我將代碼提取出來單獨寫出一篇博客《http://blog.csdn.net/qianxin_dh/article/details/39268113》  
  88.     sp->y = y;  
  89.     sp->refCount = 0;  
  90.     m_sps.push_back(sp);   //存儲sp  
  91.   
  92.     ProcessNew((int)m_sps.size()-1);  //執行該步驟,添加支持向量,並對beta值進行調整  
  93.     BudgetMaintenance();       //保證支持向量沒有超出限定閾值  
  94.       
  95.     for (int i = 0; i < 10; ++i)  
  96.     {  
  97.         Reprocess();           //包括processold:增加新的sv;optimize:在現有的sv基礎上調整beta值  
  98.         BudgetMaintenance();  
  99.     }  
  100. }  
  101.   
  102. void LaRank::BudgetMaintenance()  
  103. {  
  104.     if (m_config.svmBudgetSize > 0)  
  105.     {  
  106.         while ((int)m_svs.size() > m_config.svmBudgetSize)  
  107.         {  
  108.             BudgetMaintenanceRemove();  //支持向量的個數超出閾值後,找到對於F函數影響最小的負sv,並移除。  
  109.         }  
  110.     }  
  111. }  
  112.   
  113. void LaRank::Reprocess()  
  114. {  
  115.     ProcessOld();       //每個processold步驟伴隨着10個optimize步驟。  
  116.     for (int i = 0; i < 10; ++i)  
  117.     {  
  118.         Optimize();  
  119.     }  
  120. }  
  121.   
  122. double LaRank::ComputeDual() const  
  123. {  
  124.     double d = 0.0;  
  125.     for (int i = 0; i < (int)m_svs.size(); ++i)  
  126.     {  
  127.         const SupportVector* sv = m_svs[i];  
  128.         d -= sv->b*Loss(sv->x->yv[sv->y], sv->x->yv[sv->x->y]);  
  129.         for (int j = 0; j < (int)m_svs.size(); ++j)  
  130.         {  
  131.             d -= 0.5*sv->b*m_svs[j]->b*m_K(i,j);  
  132.         }  
  133.     }  
  134.     return d;  
  135. }  
  136.   
  137. void LaRank::SMOStep(int ipos, int ineg)  
  138. {  
  139.     if (ipos == ineg) return;  
  140.   
  141.     SupportVector* svp = m_svs[ipos];    //定義一個正支持向量  
  142.     SupportVector* svn = m_svs[ineg];    //定義一個負支持向量  
  143.     assert(svp->x == svn->x);  
  144.     SupportPattern* sp = svp->x;    //定義一個支持模式sp,將正支持向量的支持模式賦予sp  
  145.   
  146. #if VERBOSE  
  147.     cout << "SMO: gpos:" << svp->g << " gneg:" << svn->g << endl;  
  148. #endif    
  149.     if ((svp->g - svn->g) < 1e-5)  
  150.     {  
  151. #if VERBOSE  
  152.         cout << "SMO: skipping" << endl;  
  153. #endif        
  154.     }  
  155.     else  
  156.     {   //論文中的Algorithm步驟  
  157.         double kii = m_K(ipos, ipos) + m_K(ineg, ineg) - 2*m_K(ipos, ineg);  
  158.         double lu = (svp->g-svn->g)/kii;  
  159.         // no need to clamp against 0 since we'd have skipped in that case  
  160.         double l = min(lu, m_C*(int)(svp->y == sp->y) - svp->b);  
  161.   
  162.         svp->b += l;  
  163.         svn->b -= l;  
  164.   
  165.         // update gradients  
  166.         for (int i = 0; i < (int)m_svs.size(); ++i)  
  167.         {  
  168.             SupportVector* svi = m_svs[i];  
  169.             svi->g -= l*(m_K(i, ipos) - m_K(i, ineg));  
  170.         }  
  171. #if VERBOSE  
  172.         cout << "SMO: " << ipos << "," << ineg << " -- " << svp->b << "," << svn->b << " (" << l << ")" << endl;  
  173. #endif        
  174.     }  
  175.       
  176.     // check if we should remove either sv now  
  177.       
  178.     if (fabs(svp->b) < 1e-8)         //beta爲0,該向量被移除  
  179.     {  
  180.         RemoveSupportVector(ipos);  
  181.         if (ineg == (int)m_svs.size())  
  182.         {  
  183.             // ineg and ipos will have been swapped during sv removal  
  184.             ineg = ipos;  
  185.         }  
  186.     }  
  187.   
  188.     if (fabs(svn->b) < 1e-8)  //beta=0,該向量被移除  
  189.     {  
  190.         RemoveSupportVector(ineg);  
  191.     }  
  192. }  
  193.   
  194. pair<intdouble> LaRank::MinGradient(int ind)  
  195. {  
  196.     const SupportPattern* sp = m_sps[ind];  
  197.     pair<intdouble> minGrad(-1, DBL_MAX);  
  198.     for (int i = 0; i < (int)sp->yv.size(); ++i)  
  199.     {  
  200.         double grad = -Loss(sp->yv[i], sp->yv[sp->y]) - Evaluate(sp->x[i], sp->yv[i]);//通過公式10找到最小梯度對應的樣本框  
  201.         if (grad < minGrad.second)  
  202.         {  
  203.             minGrad.first = i;  
  204.             minGrad.second = grad;  
  205.         }  
  206.     }  
  207.     return minGrad;  
  208. }  
  209.   
  210. void LaRank::ProcessNew(int ind)  //可以添加新的支持向量,增加的正負支持向量(sv)具有相同的支持模式(sp)  
  211. {  
  212.     // gradient is -f(x,y) since loss=0  
  213.     int ip = AddSupportVector(m_sps[ind], m_sps[ind]->y, -Evaluate(m_sps[ind]->x[m_sps[ind]->y],m_sps[ind]->yv[m_sps[ind]->y]));  //處理當前新樣本,將上一幀目標位置作爲正向量加入  
  214.   
  215.     pair<intdouble> minGrad = MinGradient(ind);  //int,double分別是具有最小梯度的樣本框存放的位置,最小梯度的數值  
  216.     int in = AddSupportVector(m_sps[ind], minGrad.first, minGrad.second);    //將當前具有最小梯度的樣本作爲負向量加入  
  217.   
  218.     SMOStep(ip, in);   //Algorithm 1,更新beta和gradient值  
  219. }  
  220.   
  221. void LaRank::ProcessOld()  //可以添加新的支持向量  
  222. {  
  223.     if (m_sps.size() == 0) return;  
  224.   
  225.     // choose pattern to process  
  226.     int ind = rand() % m_sps.size();   //隨機選取sp  
  227.   
  228.     // find existing sv with largest grad and nonzero beta  
  229.     int ip = -1;  
  230.     double maxGrad = -DBL_MAX;  
  231.     for (int i = 0; i < (int)m_svs.size(); ++i)  
  232.     {  
  233.         if (m_svs[i]->x != m_sps[ind]) continue;  
  234.   
  235.         const SupportVector* svi = m_svs[i];  
  236.         if (svi->g > maxGrad && svi->b < m_C*(int)(svi->y == m_sps[ind]->y))   //找出符合該條件的,作爲y+,後一個條件保證了y+是從現存的sv中找出,因此不會增加新的向量  
  237.         {  
  238.             ip = i;  
  239.             maxGrad = svi->g;  
  240.         }  
  241.     }  
  242.     assert(ip != -1);  
  243.     if (ip == -1) return;  
  244.   
  245.     // find potentially new sv with smallest grad  
  246.     pair<intdouble> minGrad = MinGradient(ind);  
  247.     int in = -1;  
  248.     for (int i = 0; i < (int)m_svs.size(); ++i)  
  249.     {  
  250.         if (m_svs[i]->x != m_sps[ind]) continue;              //找出滿足該條件的,作爲y-  
  251.   
  252.         if (m_svs[i]->y == minGrad.first)  
  253.         {  
  254.             in = i;  
  255.             break;  
  256.         }  
  257.     }  
  258.     if (in == -1)  
  259.     {  
  260.         // add new sv  
  261.         in = AddSupportVector(m_sps[ind], minGrad.first, minGrad.second);  //將該樣本作爲負sv加入  
  262.     }  
  263.   
  264.     SMOStep(ip, in);    //更新beta和gradient的值  
  265. }  
  266.   
  267. void LaRank::Optimize()    //  
  268. {  
  269.     if (m_sps.size() == 0) return;  
  270.       
  271.     // choose pattern to optimize  
  272.     int ind = rand() % m_sps.size();   //隨機處理現存的sp  
  273.   
  274.     int ip = -1;  
  275.     int in = -1;  
  276.     double maxGrad = -DBL_MAX;  
  277.     double minGrad = DBL_MAX;  
  278.     for (int i = 0; i < (int)m_svs.size(); ++i)  
  279.     {  
  280.         if (m_svs[i]->x != m_sps[ind]) continue;  
  281.   
  282.         const SupportVector* svi = m_svs[i];  
  283.         if(svi->g>maxGrad&&svi->b<m_C*(int)(svi->y==m_sps->[y]))   //將滿足該條件的作爲y+  
  284.         {  
  285.             ip = i;  
  286.             maxGrad = svi->g;  
  287.         }  
  288.         if (svi->g < minGrad)                       //將滿足該條件的作爲y-  
  289.         {  
  290.             in = i;  
  291.             minGrad = svi->g;  
  292.         }  
  293.     }  
  294.     assert(ip != -1 && in != -1);  
  295.     if (ip == -1 || in == -1)  
  296.     {  
  297.         // this shouldn't happen  
  298.         cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" << endl;  
  299.         return;  
  300.     }  
  301.   
  302.     SMOStep(ip, in);         //更新beta和gradient  
  303. }  
  304.   
  305. int LaRank::AddSupportVector(SupportPattern* x, int y, double g)  
  306. {  
  307.     SupportVector* sv = new SupportVector;  
  308.     sv->b = 0.0;        //beta初始化爲0  
  309.     sv->x = x;  
  310.     sv->y = y;  
  311.     sv->g = g;  
  312.   
  313.     int ind = (int)m_svs.size();  
  314.     m_svs.push_back(sv);  
  315.     x->refCount++;  
  316.   
  317. #if VERBOSE  
  318.     cout << "Adding SV: " << ind << endl;  
  319. #endif  
  320.   
  321.     // update kernel matrix  
  322.     for (int i = 0; i < ind; ++i)    //計算核矩陣  
  323.     {  
  324.         m_K(i,ind) = m_kernel.Eval(m_svs[i]->x->x[m_svs[i]->y], x->x[y]);  
  325.         m_K(ind,i) = m_K(i,ind);  
  326.     }  
  327.     m_K(ind,ind) = m_kernel.Eval(x->x[y]);  
  328.   
  329.     return ind;  
  330. }  
  331.   
  332. void LaRank::SwapSupportVectors(int ind1, int ind2)  
  333. {  
  334.     SupportVector* tmp = m_svs[ind1];  
  335.     m_svs[ind1] = m_svs[ind2];  
  336.     m_svs[ind2] = tmp;  
  337.       
  338.     VectorXd row1 = m_K.row(ind1);  
  339.     m_K.row(ind1) = m_K.row(ind2);  
  340.     m_K.row(ind2) = row1;  
  341.       
  342.     VectorXd col1 = m_K.col(ind1);  
  343.     m_K.col(ind1) = m_K.col(ind2);  
  344.     m_K.col(ind2) = col1;  
  345. }  
  346.   
  347. void LaRank::RemoveSupportVector(int ind)  
  348. {  
  349. #if VERBOSE  
  350.     cout << "Removing SV: " << ind << endl;  
  351. #endif    
  352.   
  353.     m_svs[ind]->x->refCount--;  
  354.     if (m_svs[ind]->x->refCount == 0)  
  355.     {  
  356.         // also remove the support pattern  
  357.         for (int i = 0; i < (int)m_sps.size(); ++i)  
  358.         {  
  359.             if (m_sps[i] == m_svs[ind]->x)  
  360.             {  
  361.                 delete m_sps[i];  
  362.                 m_sps.erase(m_sps.begin()+i);  
  363.                 break;  
  364.             }  
  365.         }  
  366.     }  
  367.   
  368.     // make sure the support vector is at the back, this  
  369.     // lets us keep the kernel matrix cached and valid  
  370.     if (ind < (int)m_svs.size()-1)  
  371.     {  
  372.         SwapSupportVectors(ind, (int)m_svs.size()-1);  
  373.         ind = (int)m_svs.size()-1;  
  374.     }  
  375.     delete m_svs[ind];  
  376.     m_svs.pop_back();  
  377. }  
  378.   
  379. void LaRank::BudgetMaintenanceRemove()  
  380. {  
  381.     // find negative sv with smallest effect on discriminant function if removed  
  382.     double minVal = DBL_MAX;  
  383.     int in = -1;  
  384.     int ip = -1;  
  385.     for (int i = 0; i < (int)m_svs.size(); ++i)  
  386.     {  
  387.         if (m_svs[i]->b < 0.0)           //找到負sv  
  388.         {  
  389.             // find corresponding positive sv  
  390.             int j = -1;  
  391.             for (int k = 0; k < (int)m_svs.size(); ++k)  
  392.             {  
  393.                 if (m_svs[k]->b > 0.0 && m_svs[k]->x == m_svs[i]->x)   //找到同一支持模式下的正sv  
  394.                 {  
  395.                     j = k;  
  396.                     break;  
  397.                 }  
  398.             }  
  399.             double val = m_svs[i]->b*m_svs[i]->b*(m_K(i,i) + m_K(j,j) - 2.0*m_K(i,j));  
  400.             if (val < minVal)         //找到對F影響最小的sv  
  401.             {  
  402.                 minVal = val;  
  403.                 in = i;  
  404.                 ip = j;  
  405.             }  
  406.         }  
  407.     }  
  408.   
  409.     // adjust weight of positive sv to compensate for removal of negative  
  410.     m_svs[ip]->b += m_svs[in]->b;    //將負sv移除,其相應的beta值需補償到正sv上。  
  411.   
  412.     // remove negative sv  
  413.     RemoveSupportVector(in);  
  414.     if (ip == (int)m_svs.size())  
  415.     {  
  416.         // ip and in will have been swapped during support vector removal  
  417.         ip = in;  
  418.     }  
  419.       
  420.     if (m_svs[ip]->b < 1e-8)     //beta值爲0,移除該向量  
  421.     {  
  422.         // also remove positive sv  
  423.         RemoveSupportVector(ip);  
  424.     }  
  425.   
  426.     // update gradients  
  427.     // TODO: this could be made cheaper by just adjusting incrementally rather than recomputing  
  428.     for (int i = 0; i < (int)m_svs.size(); ++i)  
  429.     {  
  430.         SupportVector& svi = *m_svs[i];  
  431.         svi.g = -Loss(svi.x->yv[svi.y],svi.x->yv[svi.x->y]) - Evaluate(svi.x->x[svi.y], svi.x->yv[svi.y]);  
  432.     }     
  433. }  
  434.   
  435. void LaRank::Debug()  
  436. {  
  437.     cout << m_sps.size() << "/" << m_svs.size() << " support patterns/vectors" << endl;  
  438.     UpdateDebugImage();  
  439.     imshow("learner", m_debugImage);  
  440. }  
  441.   
  442. void LaRank::UpdateDebugImage()    //該函數主要用於樣本顯示,與算法關係不大,這裏不做分析了  
  443. {  
  444.     m_debugImage.setTo(0);  
  445.       
  446.     int n = (int)m_svs.size();  
  447.       
  448.     if (n == 0) return;  
  449.       
  450.     const int kCanvasSize = 600;  
  451.     int gridSize = (int)sqrtf((float)(n-1)) + 1;  
  452.     int tileSize = (int)((float)kCanvasSize/gridSize);  
  453.       
  454.     if (tileSize < 5)  
  455.     {  
  456.         cout << "too many support vectors to display" << endl;  
  457.         return;  
  458.     }  
  459.       
  460.     Mat temp(tileSize, tileSize, CV_8UC1);  
  461.     int x = 0;  
  462.     int y = 0;  
  463.     int ind = 0;  
  464.     float vals[kMaxSVs];  
  465.     memset(vals, 0, sizeof(float)*n);  
  466.     int drawOrder[kMaxSVs];  
  467.       
  468.     for (int set = 0; set < 2; ++set)  
  469.     {  
  470.         for (int i = 0; i < n; ++i)  
  471.         {  
  472.             if (((set == 0) ? 1 : -1)*m_svs[i]->b < 0.0) continue;  
  473.               
  474.             drawOrder[ind] = i;  
  475.             vals[ind] = (float)m_svs[i]->b;  
  476.             ++ind;  
  477.               
  478.             Mat I = m_debugImage(cv::Rect(x, y, tileSize, tileSize));  
  479.             resize(m_svs[i]->x->images[m_svs[i]->y], temp, temp.size());  
  480.             cvtColor(temp, I, CV_GRAY2RGB);  
  481.             double w = 1.0;  
  482.             rectangle(I, Point(0, 0), Point(tileSize-1, tileSize-1), (m_svs[i]->b > 0.0) ? CV_RGB(0, (uchar)(255*w), 0) : CV_RGB((uchar)(255*w), 0, 0), 3);  
  483.             x += tileSize;  
  484.             if ((x+tileSize) > kCanvasSize)  
  485.             {  
  486.                 y += tileSize;  
  487.                 x = 0;  
  488.             }  
  489.         }  
  490.     }  
  491.       
  492.     const int kKernelPixelSize = 2;  
  493.     int kernelSize = kKernelPixelSize*n;  
  494.       
  495.     double kmin = m_K.minCoeff();  
  496.     double kmax = m_K.maxCoeff();  
  497.       
  498.     if (kernelSize < m_debugImage.cols && kernelSize < m_debugImage.rows)  
  499.     {  
  500.         Mat K = m_debugImage(cv::Rect(m_debugImage.cols-kernelSize, m_debugImage.rows-kernelSize, kernelSize, kernelSize));  
  501.         for (int i = 0; i < n; ++i)  
  502.         {  
  503.             for (int j = 0; j < n; ++j)  
  504.             {  
  505.                 Mat Kij = K(cv::Rect(j*kKernelPixelSize, i*kKernelPixelSize, kKernelPixelSize, kKernelPixelSize));  
  506.                 uchar v = (uchar)(255*(m_K(drawOrder[i], drawOrder[j])-kmin)/(kmax-kmin));  
  507.                 Kij.setTo(Scalar(v, v, v));  
  508.             }  
  509.         }  
  510.     }  
  511.     else  
  512.     {  
  513.         kernelSize = 0;  
  514.     }  
  515.       
  516.       
  517.     Mat I = m_debugImage(cv::Rect(0, m_debugImage.rows - 200, m_debugImage.cols-kernelSize, 200));  
  518.     I.setTo(Scalar(255,255,255));  
  519.     IplImage II = I;  
  520.     setGraphColor(0);  
  521.     drawFloatGraph(vals, n, &II, 0.f, 0.f, I.cols, I.rows);  
  522. }  


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章