Opencv研讀筆記:haartraining程序之icvCreateCARTStageClassifier函數詳解(強分類器創建)~

之前介紹了haartraining程序中的cvCreateMTStumpClassifier函數,這個函數的功能是計算最優弱分類器,這篇文章介紹一下自己對haartraining中關於強分類器計算的一些理解,也就是程序中的icvCreateCARTStageClassifier函數。


在給出代碼之前,說幾處自認爲值得說說的問題:

1. 由於haartraining是基於HAAR特徵進行adaboost訓練,對於HAAR特徵的處理比較繁瑣,採用了奇數弱分類器補充針對翻轉特徵最優弱分類器計算的代碼,所以代碼看起來較爲冗長。

2. 創建強分類器時,其中包含有樣本權值的更新,代碼中共提供了四種經典adaboost算法版本,它們是Discrete Adaboost、Real Adaboost、Logit Boost、Gentle Adaboost。每種算法的權值更新策略不同,這方面的知識建議大家下載幾篇博士論文看看,也可以看看我之前發的博客。

http://blog.csdn.net/wsj998689aa/article/details/42242565

3. 代碼較多地通過函數指針的形式(實際上這也是opencv一直常用的手段)對函數進行回調。

4. 小權值樣本需要剔除掉,因爲小權值樣本對訓練結果的影響微乎其微,加了它們反而要耗時不少。這邊有一處需要提醒大家,當前分類器剔除掉的小權值樣本,仍舊參與下一個分類器中樣本權值的剔除,換句話說,每一個弱分類器,其實都要對所有的樣本權值進行排序,所以會造成實際訓練樣本比例出現“跳變”的情況,但是總體走勢還是始終下降的。如下圖所示:


5. 代碼中採用了較多的中間結構體變量,例如CvIntHaarClassifier結構體(用於模擬強分類器結構體CvStageHaarClassifier的父類),CvBoostTrainer結構體(用於初始化,更新樣本權值等)等等,看起來比較繞。

6. 關於弱分類器的創建,事先創建的其實是CART分類器,CART分類器就是一棵樹,每個節點代表一個最優Haar特徵,但是一般程序中的節點個數都設置爲1,所以一個CART就相當於stump了,此外,CART的創建涉及到節點的分裂,通過icvSplitIndicesCallback函數實現。

7. 在創建CART的時候,最優Haar特徵其實就已經被選擇好了,至於下面還有一個stumpConstructor函數,是由於Haar特徵被翻轉而產生了新特徵,所以需要重新尋找最優弱分類器。

8. 有意思的是,函數輸入了最小正檢率和最大誤檢率,前者決定了判定樣本是正類還是負類的閾值,後者決定了強分類器是否能夠收斂,具體來說吧,先計算當前最優弱分類器關於每個正樣本的置信度,然後對置信度進行排序,這個閾值就是基於最小正檢率選擇的置信度,然後在根據這個閾值來計算當前最優弱分類器的誤檢率(計算每個負樣本的置信度),如果大於了輸入的最大誤檢率,那麼OK!!一串弱分類器所構成的強分類器誕生了。


以上說的就是icvCreateCARTStageClassifier中值得注意的幾點,下面上代碼,是根據自己的理解添加的註釋,請各位不吝批評指正哈!

轉載請註明:http://blog.csdn.net/wsj998689aa/article/details/42398235

static
CvIntHaarClassifier* icvCreateCARTStageClassifier( CvHaarTrainingData* data,        // 全部訓練樣本
                                                   CvMat* sampleIdx,                // 實際訓練樣本序列
                                                   CvIntHaarFeatures* haarFeatures, // 全部HAAR特徵
                                                   float minhitrate,    // 最小正檢率(用於確定強分類器閾值)
                                                   float maxfalsealarm, // 最大誤檢率(用於確定是否收斂)            
                                                   int   symmetric,     // HAAR是否對稱
                                                   float weightfraction,    // 樣本剔除比例(用於剔除小權值樣本)
                                                   int numsplits,           // 每個弱分類器特徵個數(一般爲1)
                                                   CvBoostType boosttype,   // adaboost類型
                                                   CvStumpError stumperror, // Discrete AdaBoost中的閾值計算方式
                                                   int maxsplits )          // 弱分類器最大個數
{

#ifdef CV_COL_ARRANGEMENT
    int flags = CV_COL_SAMPLE;
#else
    int flags = CV_ROW_SAMPLE;
#endif

    CvStageHaarClassifier* stage = NULL;                    // 強分類器
    CvBoostTrainer* trainer;                                // 臨時訓練器,用於更新樣本權值
    CvCARTClassifier* cart = NULL;                          // 弱分類器
    CvCARTTrainParams trainParams;                          // 訓練參數
    CvMTStumpTrainParams stumpTrainParams;                  // 弱分類器參數
    //CvMat* trainData = NULL;
    //CvMat* sortedIdx = NULL;
    CvMat eval;                                             // 臨時矩陣
    int n = 0;                                              // 特徵總數
    int m = 0;                                              // 總樣本個數
    int numpos = 0;                                         // 正樣本個數
    int numneg = 0;                                         // 負樣本個數
    int numfalse = 0;                                       // 誤檢樣本個數
    float sum_stage = 0.0F;                                 // 置信度累積和                              
    float threshold = 0.0F;                                 // 強分類器閾值
    float falsealarm = 0.0F;                                // 誤檢率
    
    //CvMat* sampleIdx = NULL;
    CvMat* trimmedIdx;                                      // 剔除小權值之後的樣本序列
    //float* idxdata = NULL;
    //float* tempweights = NULL;
    //int    idxcount = 0;
    CvUserdata userdata;                                    // 訓練數據

    int i = 0;
    int j = 0;
    int idx;
    int numsamples;                                         // 實際樣本個數
    int numtrimmed;                                         // 剔除小權值之後的樣本個數
    
    CvCARTHaarClassifier* classifier;                       // 弱分類器
    CvSeq* seq = NULL;
    CvMemStorage* storage = NULL;
    CvMat* weakTrainVals;                                   // 樣本類別,只有logitboost纔會用到
    float alpha;
    float sumalpha;
    int num_splits;                                         // 弱分類器個數                                    

#ifdef CV_VERBOSE
    printf( "+----+----+-+---------+---------+---------+---------+\n" );
    printf( "|  N |%%SMP|F|  ST.THR |    HR   |    FA   | EXP. ERR|\n" );
    printf( "+----+----+-+---------+---------+---------+---------+\n" );
#endif /* CV_VERBOSE */
    
    n = haarFeatures->count;
    m = data->sum.rows;
    numsamples = (sampleIdx) ? MAX( sampleIdx->rows, sampleIdx->cols ) : m;

    // 樣本與HAAR特徵
    userdata = cvUserdata( data, haarFeatures );


    /* 弱分類參數設置 */
    stumpTrainParams.type = ( boosttype == CV_DABCLASS )
        ? CV_CLASSIFICATION_CLASS : CV_REGRESSION;                              // 分類或者回歸
    stumpTrainParams.error = ( boosttype == CV_LBCLASS || boosttype == CV_GABCLASS )
        ? CV_SQUARE : stumperror;                                               // 弱分類器閾值計算方式
    stumpTrainParams.portion = CV_STUMP_TRAIN_PORTION;                          // 每組特徵個數
    stumpTrainParams.getTrainData = icvGetTrainingDataCallback;                 // 計算樣本的haar值
    stumpTrainParams.numcomp = n;                                               // 特徵個數            
    stumpTrainParams.userdata = &userdata; 
    stumpTrainParams.sortedIdx = data->idxcache;                                // 特徵-樣本序號矩陣(排序之後)


    // 由於參數衆多,所以創建參數結構體
    trainParams.count = numsplits;                                              // 弱分類器特徵樹
    trainParams.stumpTrainParams = (CvClassifierTrainParams*) &stumpTrainParams;// 弱分類參數
    trainParams.stumpConstructor = cvCreateMTStumpClassifier;                   // 篩選最優弱分類器
    trainParams.splitIdx = icvSplitIndicesCallback;                             // CART節點分裂函數
    trainParams.userdata = &userdata;                                               

    // 臨時向量,用於存放樣本haar特徵值
    eval = cvMat( 1, m, CV_32FC1, cvAlloc( sizeof( float ) * m ) );
    
    storage = cvCreateMemStorage();

    // 最優弱分類器存儲序列
    seq = cvCreateSeq( 0, sizeof( *seq ), sizeof( classifier ), storage );

    // 樣本類別,只有logitboost纔會用到
    weakTrainVals = cvCreateMat( 1, m, CV_32FC1 );

    // 初始化樣本類別與權重,weakTrainVals爲{-1, 1},權重都一樣
    trainer = cvBoostStartTraining( &data->cls, weakTrainVals, &data->weights,
                                    sampleIdx, boosttype );
    num_splits = 0;
    sumalpha = 0.0F;
    do
    {     

#ifdef CV_VERBOSE
        int v_wt = 0;
        int v_flipped = 0;
#endif /* CV_VERBOSE */

        // 剔除小權值樣本
        trimmedIdx = cvTrimWeights( &data->weights, sampleIdx, weightfraction );

        // 實際樣本總數
        numtrimmed = (trimmedIdx) ? MAX( trimmedIdx->rows, trimmedIdx->cols ) : m;

#ifdef CV_VERBOSE
        v_wt = 100 * numtrimmed / numsamples;
        v_flipped = 0;

#endif /* CV_VERBOSE */

        // 重要函數,創建CART樹的同時,當前最優弱分類器出爐,一般只有根節點
        cart = (CvCARTClassifier*) cvCreateCARTClassifier( data->valcache,
                        flags,
                        weakTrainVals, 0, 0, 0, trimmedIdx,
                        &(data->weights),
                        (CvClassifierTrainParams*) &trainParams );

        // 創建弱分類器
        classifier = (CvCARTHaarClassifier*) icvCreateCARTHaarClassifier( numsplits );

        // 將CART樹轉化爲弱分類器
        icvInitCARTHaarClassifier( classifier, cart, haarFeatures );

        num_splits += classifier->count;

        cart->release( (CvClassifier**) &cart );
        
        // 爲何一定要在奇數個弱分類器處計算?
        if( symmetric && (seq->total % 2) )
        {
            float normfactor = 0.0F;
            CvStumpClassifier* stump;
            
            /* 翻轉HAAR特徵 */
            for( i = 0; i < classifier->count; i++ )
            {
                if( classifier->feature[i].desc[0] == 'h' )
                {
                    for( j = 0; j < CV_HAAR_FEATURE_MAX &&
                                    classifier->feature[i].rect[j].weight != 0.0F; j++ )
                    {
                        classifier->feature[i].rect[j].r.x = data->winsize.width - 
                            classifier->feature[i].rect[j].r.x -
                            classifier->feature[i].rect[j].r.width;                
                    }
                }
                else
                {
                    int tmp = 0;

                    /* (x,y) -> (24-x,y) */
                    /* w -> h; h -> w    */
                    for( j = 0; j < CV_HAAR_FEATURE_MAX &&
                                    classifier->feature[i].rect[j].weight != 0.0F; j++ )
                    {
                        classifier->feature[i].rect[j].r.x = data->winsize.width - 
                            classifier->feature[i].rect[j].r.x;
                        CV_SWAP( classifier->feature[i].rect[j].r.width,
                                 classifier->feature[i].rect[j].r.height, tmp );
                    }
                }
            }

            // 轉化爲基於積分圖計算的特徵
            icvConvertToFastHaarFeature( classifier->feature,
                                         classifier->fastfeature,
                                         classifier->count, data->winsize.width + 1 );

            // 爲了驗證最新翻轉特徵是否爲最優特徵
            stumpTrainParams.getTrainData = NULL;
            stumpTrainParams.numcomp = 1;
            stumpTrainParams.userdata = NULL;
            stumpTrainParams.sortedIdx = NULL;

            // 驗證是否新生成的特徵可作爲最優弱分類器
            for( i = 0; i < classifier->count; i++ )
            {
                for( j = 0; j < numtrimmed; j++ )
                {
                    // 獲取訓練樣本
                    idx = icvGetIdxAt( trimmedIdx, j );

                    // 對每個訓練樣本計算Haar特徵
                    eval.data.fl[idx] = cvEvalFastHaarFeature( &classifier->fastfeature[i],
                                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step) ); 

                    // 歸一化因子
                    normfactor = data->normfactor.data.fl[idx];

                    // 對Haar特徵歸一化
                    eval.data.fl[idx] = ( normfactor == 0.0F )
                        ? 0.0F : (eval.data.fl[idx] / normfactor);
                }

                // 計算最優弱分類器
                stump = (CvStumpClassifier*) trainParams.stumpConstructor( &eval,
                    CV_COL_SAMPLE,
                    weakTrainVals, 0, 0, 0, trimmedIdx,
                    &(data->weights),
                    trainParams.stumpTrainParams );
            
                classifier->threshold[i] = stump->threshold;                // 閾值
                if( classifier->left[i] <= 0 )
                {
                    classifier->val[-classifier->left[i]] = stump->left;    // 左分支輸出置信度
                }
                if( classifier->right[i] <= 0 )
                {
                    classifier->val[-classifier->right[i]] = stump->right;  // 右分支輸出置信度
                }

                stump->release( (CvClassifier**) &stump );        
                
            }

            // 還原參數,參數支持cvCreateCARTClassifier函數
            stumpTrainParams.getTrainData = icvGetTrainingDataCallback;
            stumpTrainParams.numcomp = n;
            stumpTrainParams.userdata = &userdata;
            stumpTrainParams.sortedIdx = data->idxcache;

#ifdef CV_VERBOSE
            v_flipped = 1;
#endif /* CV_VERBOSE */

        } /* if symmetric */
        if( trimmedIdx != sampleIdx )
        {
            cvReleaseMat( &trimmedIdx );
            trimmedIdx = NULL;
        }
        
        // 調用icvEvalCARTHaarClassifier函數,計算每個樣本的當前最優弱分類器置信度
        for( i = 0; i < numsamples; i++ )
        {
            idx = icvGetIdxAt( sampleIdx, i );

            eval.data.fl[idx] = classifier->eval( (CvIntHaarClassifier*) classifier,
                (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                data->normfactor.data.fl[idx] );
        }

        // 更新樣本權重,如果是LogitBoost,也會更新weakTrainVals,函數返回的是弱分類器權重
        alpha = cvBoostNextWeakClassifier( &eval, &data->cls, weakTrainVals,
                                           &data->weights, trainer );
        
        // 這個變量沒什麼用
        sumalpha += alpha;
        
        for( i = 0; i <= classifier->count; i++ )
        {
            if( boosttype == CV_RABCLASS ) 
            {
                classifier->val[i] = cvLogRatio( classifier->val[i] );
            }
            classifier->val[i] *= alpha;
        }

        // 添加弱分類器
        cvSeqPush( seq, (void*) &classifier );

        // 正樣本個數
        numpos = 0;

        // 遍歷sampleIdx中所有樣本,計算每個樣本的弱分類器置信度和
        for( i = 0; i < numsamples; i++ )
        {
            // 獲得樣本序號
            idx = icvGetIdxAt( sampleIdx, i );

            // 如果樣本爲正樣本
            if( data->cls.data.fl[idx] == 1.0F )
            {
                // 初始化置信度值
                eval.data.fl[numpos] = 0.0F;

                // 遍歷seq中所有弱分類器
                for( j = 0; j < seq->total; j++ )
                {
                    // 獲取弱分類器
                    classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));

                    // 累積當前正樣本的弱分類器置信度和
                    eval.data.fl[numpos] += classifier->eval( 
                        (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
                /* eval.data.fl[numpos] = 2.0F * eval.data.fl[numpos] - seq->total; */
                numpos++;
            }
        }

        // 對弱分類器輸出置信度和進行排序
        icvSort_32f( eval.data.fl, numpos, 0 );

        // 計算閾值,應該是大於threshold則爲正類,小於threshold則爲負類
        threshold = eval.data.fl[(int) ((1.0F - minhitrate) * numpos)];

        numneg = 0;
        numfalse = 0;

        // 遍歷所有樣本,統計錯分負樣本個數
        for( i = 0; i < numsamples; i++ )
        {
            idx = icvGetIdxAt( sampleIdx, i );

            // 如果樣本爲負樣本
            if( data->cls.data.fl[idx] == 0.0F )
            {
                numneg++;
                sum_stage = 0.0F;

                // 遍歷seq中所有弱分類器
                for( j = 0; j < seq->total; j++ )
                {
                   classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));

                   // 累積當前負樣本的分類器輸出結果
                   sum_stage += classifier->eval( (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
                /* sum_stage = 2.0F * sum_stage - seq->total; */

                // 因爲小於threshold爲負類,所以下面是分類錯誤的情況
                if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )
                {
                    numfalse++;
                }
            }
        }

        // 計算虛警率
        falsealarm = ((float) numfalse) / ((float) numneg);

// 輸出內容
#ifdef CV_VERBOSE
        {
            // 正樣本檢出率
            float v_hitrate    = 0.0F;

            // 負樣本誤檢率
            float v_falsealarm = 0.0F;
            /* expected error of stage classifier regardless threshold */

            // 這是什麼?
            float v_experr = 0.0F;

            // 遍歷所有樣本
            for( i = 0; i < numsamples; i++ )
            {
                idx = icvGetIdxAt( sampleIdx, i );

                sum_stage = 0.0F;

                // 遍歷seq中所有弱分類器
                for( j = 0; j < seq->total; j++ )
                {
                    classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));
                    sum_stage += classifier->eval( (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
                /* sum_stage = 2.0F * sum_stage - seq->total; */

                // 只需要判斷單一分支即可
                if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )
                {
                    if( data->cls.data.fl[idx] == 1.0F )
                    {
                        v_hitrate += 1.0F;
                    }
                    else
                    {
                        v_falsealarm += 1.0F;
                    }
                }

                // 正類樣本的sum_stage必須大於0
                if( ( sum_stage >= 0.0F ) != (data->cls.data.fl[idx] == 1.0F) )
                {
                    v_experr += 1.0F;
                }
            }
            v_experr /= numsamples;
            printf( "|%4d|%3d%%|%c|%9f|%9f|%9f|%9f|\n",
                seq->total, v_wt, ( (v_flipped) ? '+' : '-' ),
                threshold, v_hitrate / numpos, v_falsealarm / numneg,
                v_experr );
            printf( "+----+----+-+---------+---------+---------+---------+\n" );
            fflush( stdout );
        }
#endif /* CV_VERBOSE */
        
    // 兩種收斂方式,一種是誤檢率小於規定閾值,另一種是弱分類器個數小於規定閾值
    } while( falsealarm > maxfalsealarm && (!maxsplits || (num_splits < maxsplits) ) );
    cvBoostEndTraining( &trainer );

    if( falsealarm > maxfalsealarm )        
    {
        // 如果弱分類器達到上限而收斂,則放棄當前強分類器
        stage = NULL;
    }
    else
    {
        // 創建當前強分類器
        stage = (CvStageHaarClassifier*) icvCreateStageHaarClassifier( seq->total,
                                                                       threshold );
        // 保存當前強分類器
        cvCvtSeqToArray( seq, (CvArr*) stage->classifier );
    }
    
    /* CLEANUP */
    cvReleaseMemStorage( &storage );
    cvReleaseMat( &weakTrainVals );
    cvFree( &(eval.data.ptr) );
    
    return (CvIntHaarClassifier*) stage;
}



發佈了42 篇原創文章 · 獲贊 193 · 訪問量 50萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章