Opencv研讀筆記:haartraining程序之cvCreateMTStumpClassifier函數詳解(弱分類器創建)~

cvCreateMTStumpClassifier函數出自opencv中的haartraining程序,在Haartraining中強分類器創建函數icvCreateCARTStageClassifier中被兩次調用,該函數用於尋找最優弱分類器,或者說成計算最優Haar特徵。功能很明確,但是大家都知道的,opencv的代碼絕大部分寫的讓人真心看不懂,這個函數算是haartraining中比較難以看懂的函數,局部變量達到20個之多,童鞋我也是不甘心,不甘心被這小小的函數所擊潰,於是擦乾淚水,仔細研讀,終於恍然大悟,大徹大悟的同時,不忘回報CSDN博客,與朋友們分享。

1. 最優弱分類器的計算過程,網上到處都有介紹,其實就是個窮舉的過程,先對每個特徵所對應的訓練樣本的特徵值進行排序,然後遍歷每個特徵值作爲閾值,根據特定的方法(1.misclass 2.gini 3.entropy 4.least sum of squares確定最優閾值,進一步確定最優特徵,也就是最優弱分類器了。

2. opencv寫的比較通用,所以有點讓人摸不清頭腦,它是這麼幹的:先預計算“Haar特徵-訓練樣本”矩陣(trainData),一般Haar特徵有600個,然後先尋找這600個特徵中的最優特徵,但是總共的Haar特徵可能有1萬多個,對於新特徵那就只能在重新計算訓練樣本升序矩陣(mat),繼續尋找最優特徵。

3. 由於程序一直躍躍欲試,想用並行的方法處理,導致了程序的局部變量增多(例如portion的引用),這是值得大家注意的地方。

4. 另外,上面說到cvCreateMTStumpClassifier函數被兩次調用,一次是在cvCreateCARTClassifier中,一次是在icvCreateCARTStageClassifier中,其中,前者中trainData對應的是一個矩陣,而後者trainData對應的是一個行向量。

注意上面幾處,再看源碼,應該就不會被弄暈了,我直接上代碼,並且做了比較詳細的註釋,這樣子更加實在一些,希望能夠對童鞋們有所幫助!

(轉載請註明:http://blog.csdn.net/wsj998689aa/article/details/42294703,作者:迷霧forest)


// 函數功能:計算最優弱分類器
CV_BOOST_IMPL
CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData,      // 訓練樣本HAAR特徵值矩陣
                      int flags,                                // 1.按行排列,0.按列排列
                      CvMat* trainClasses,                      // 樣本類別{-1,1}
                      CvMat* /*typeMask*/,                      // 爲了便於回調函數統一格式
                      CvMat* missedMeasurementsMask,            // 未知,很少用到
                      CvMat* compIdx,                           // 特徵序列(必須爲NULL)(行向量)
                      CvMat* sampleIdx,                         // 實際訓練樣本序列(行向量)
                      CvMat* weights,                           // 實際訓練樣本樣本權重(行向量)
                      CvClassifierTrainParams* trainParams )    // 其它數據&參數
{
    CvStumpClassifier* stump = NULL;        // 弱分類器(樁)
    int m = 0;                              // 樣本總數
    int n = 0;                              // 所有特徵個數   
    uchar* data = NULL;                     // trainData數據指針
    size_t cstep   = 0;                     // trainData一行字節數
    size_t sstep   = 0;                     // trainData元素字節數
    int    datan   = 0;                     // 預計算特徵個數
    uchar* ydata = NULL;                    // trainClasses數據指針
    size_t ystep = 0;                       // trainClasses元素字節數
    uchar* idxdata = NULL;                  // sampleIdx數據指針
    size_t idxstep = 0;                     // sampleIdx單個元素字節數
    int    l = 0;                           // 實際訓練樣本個數    
    uchar* wdata = NULL;                    // weights數據指針
    size_t wstep = 0;                       // weights元素字節數

    /* sortedIdx爲事先計算好的特徵值-樣本矩陣,包含有預計算的所有HAAR特徵對應於所有樣本的特徵值(按大小排列) */
    uchar* sorteddata = NULL;               // sortedIdx數據指針
    int    sortedtype    = 0;               // sortedIdx元素類型
    size_t sortedcstep   = 0;               // sortedIdx一行字節數
    size_t sortedsstep   = 0;               // sortedIdx元素字節數
    int    sortedn       = 0;               // sortedIdx行數(預計算特徵個數)
    int    sortedm       = 0;               // sortedIdx列數(實際訓練樣本個數)

    char* filter = NULL;                    // 樣本存在標示(行向量),如果樣本存在則爲1,否則爲0
    int i = 0;
    
    int compidx = 0;                        // 每組特徵的起始序號
    int stumperror;                         // 計算閾值方法:1.misclass 2.gini 3.entropy 4.least sum of squares
    int portion;                            // 每組特徵個數,對所有特徵n進行分組處理,每組portion個

    /* private variables */
    CvMat mat;                              // 補充特徵-樣本矩陣
    CvValArray va;
    float lerror;                           // 閾值左側誤差
    float rerror;                           // 閾值右側誤差
    float left;<span style="white-space:pre">			</span>            // 置信度(左分支)
    float right;<span style="white-space:pre">			</span>    // 置信度(右分支)
    float threshold;                        // 閾值
    int optcompidx;                         // 最優特徵

    float sumw;                             
    float sumwy;
    float sumwyy;

    /*臨時變量,循環用*/
    int t_compidx;
    int t_n;
    
    int ti;
    int tj;
    int tk;

    uchar* t_data;                          // 指向data
    size_t t_cstep;                         // cstep
    size_t t_sstep;                         // sstep

    size_t matcstep;                        // mat一行字節數
    size_t matsstep;                        // mat元素字節數

    int* t_idx;                             // 樣本序列
    /* end private variables */

    CV_Assert( trainParams != NULL );
    CV_Assert( trainClasses != NULL );
    CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
    CV_Assert( missedMeasurementsMask == NULL );
    CV_Assert( compIdx == NULL );

    // 計算閾值方法:1.misclass 2.gini 3.entropy 4.least sum of squares
    stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;

    // 樣本類別
    ydata = trainClasses->data.ptr;
    if( trainClasses->rows == 1 )
    {
        m = trainClasses->cols;
        ystep = CV_ELEM_SIZE( trainClasses->type );
    }
    else
    {
        m = trainClasses->rows;
        ystep = trainClasses->step;
    }

    // 樣本權重
    wdata = weights->data.ptr;
    if( weights->rows == 1 )
    {
        CV_Assert( weights->cols == m );
        wstep = CV_ELEM_SIZE( weights->type );
    }
    else
    {
        CV_Assert( weights->rows == m );
        wstep = weights->step;
    }

    // sortedIdx爲空,trainData爲行向量(1*m);sortedIdx不爲空,trainData爲矩陣(m*datan);
    if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
    {
        sortedtype =
            CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
        assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
                || sortedtype == CV_32FC1 );
        sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
        sortedsstep = CV_ELEM_SIZE( sortedtype );
        sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
        sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
        sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
    }

    if( trainData == NULL )                         // 爲空的情況沒有遇到
    {
        assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
        n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
        assert( n > 0 );
    }
    else
    {
        assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
        data = trainData->data.ptr;
        if( CV_IS_ROW_SAMPLE( flags ) )             // trainData爲矩陣
        {
            cstep = CV_ELEM_SIZE( trainData->type );
            sstep = trainData->step;
            assert( m == trainData->rows );
            datan = n = trainData->cols;
        }
        else                                        // trainData爲向量
        {
            sstep = CV_ELEM_SIZE( trainData->type );
            cstep = trainData->step;
            assert( m == trainData->cols );
            datan = n = trainData->rows;
        }

        // trainData爲矩陣,當trainData爲向量時,datan = n = 1
        if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
        {
            n = ((CvMTStumpTrainParams*) trainParams)->numcomp;     // 總特徵個數  
        }        
    }

    // 預計算特徵個數一定要小於特徵總數
    assert( datan <= n );

    if( sampleIdx != NULL )     // 已經剔除小權值樣本
    {
        assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
        idxdata = sampleIdx->data.ptr;
        idxstep = ( sampleIdx->rows == 1 )
            ? CV_ELEM_SIZE( sampleIdx->type ) : sampleIdx->step;
        l = ( sampleIdx->rows == 1 ) ? sampleIdx->cols : sampleIdx->rows;

        // sorteddata中存放的是所有訓練樣本,需要篩選出實際訓練樣本
        if( sorteddata != NULL )
        {
            filter = (char*) cvAlloc( sizeof( char ) * m );
            memset( (void*) filter, 0, sizeof( char ) * m );
            for( i = 0; i < l; i++ )
            {
                filter[(int) *((float*) (idxdata + i * idxstep))] = (char) 1;   // 存在則爲1,不存在則爲0
            }
        }
    }
    else                        // 未剔除小權值樣本
    {
        l = m;
    }

    // 樁
    stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
    memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );

    // 每組特徵個數
    portion = ((CvMTStumpTrainParams*)trainParams)->portion;
    
    if( portion < 1 )
    {
        /* auto portion */
        portion = n;
        #ifdef _OPENMP
        portion /= omp_get_max_threads();        
        #endif /* _OPENMP */        
    }

    stump->eval = cvEvalStumpClassifier;
    stump->tune = NULL;
    stump->save = NULL;
    stump->release = cvReleaseStumpClassifier;

    stump->lerror = FLT_MAX;
    stump->rerror = FLT_MAX;
    stump->left  = 0.0F;
    stump->right = 0.0F;

    compidx = 0;

    // 並行計算,默認爲關閉的
    #ifdef _OPENMP
    #pragma omp parallel private(mat, va, lerror, rerror, left, right, threshold, \
                                 optcompidx, sumw, sumwy, sumwyy, t_compidx, t_n, \
                                 ti, tj, tk, t_data, t_cstep, t_sstep, matcstep,  \
                                 matsstep, t_idx)
    #endif /* _OPENMP */
    {
        lerror = FLT_MAX;
        rerror = FLT_MAX;
        left  = 0.0F;
        right = 0.0F;
        threshold = 0.0F;
        optcompidx = 0;

        sumw   = FLT_MAX;
        sumwy  = FLT_MAX;
        sumwyy = FLT_MAX;

        t_compidx = 0;
        t_n = 0;
        
        ti = 0;
        tj = 0;
        tk = 0;

        t_data = NULL;
        t_cstep = 0;
        t_sstep = 0;

        matcstep = 0;
        matsstep = 0;

        t_idx = NULL;

        mat.data.ptr = NULL;
        
        // 預計算特徵個數小於特徵總數,則說明存在新特徵,用於計算樣本的新特徵,存放在mat中
        if( datan < n )
        {
            if( CV_IS_ROW_SAMPLE( flags ) )
            {
                mat = cvMat( m, portion, CV_32FC1, 0 );
                matcstep = CV_ELEM_SIZE( mat.type );
                matsstep = mat.step;
            }
            else
            {
                mat = cvMat( portion, m, CV_32FC1, 0 );
                matcstep = mat.step;
                matsstep = CV_ELEM_SIZE( mat.type );
            }
            mat.data.ptr = (uchar*) cvAlloc( sizeof( float ) * mat.rows * mat.cols );
        }

        // 將實際訓練樣本序列存放進t_idx
        if( filter != NULL || sortedn < n )
        {
            t_idx = (int*) cvAlloc( sizeof( int ) * m );
            if( sortedn == 0 || filter == NULL )
            {
                if( idxdata != NULL )
                {
                    for( ti = 0; ti < l; ti++ )
                    {
                        t_idx[ti] = (int) *((float*) (idxdata + ti * idxstep));
                    }
                }
                else
                {
                    for( ti = 0; ti < l; ti++ )
                    {
                        t_idx[ti] = ti;
                    }
                }                
            }
        }

        #ifdef _OPENMP
        #pragma omp critical(c_compidx)
        #endif /* _OPENMP */

        // 初始化計算特徵範圍
        {
            t_compidx = compidx;
            compidx += portion;
        }

        // 尋找最優弱分類器
        while( t_compidx < n )
        {
            t_n = portion;                      // 每組特徵個數
            if( t_compidx < datan )             // 已經計算過的特徵
            {
                t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
                t_data = data;
                t_cstep = cstep;
                t_sstep = sstep;
            }
            else                                // 新特徵
            {
                t_n = ( t_n < (n - t_compidx) ) ? t_n : (n - t_compidx);
                t_cstep = matcstep;
                t_sstep = matsstep;
                t_data = mat.data.ptr - t_compidx * ((size_t) t_cstep );

                // 計算每個新特徵對應於每個訓練樣本的特徵值
                ((CvMTStumpTrainParams*)trainParams)->getTrainData( &mat,
                        sampleIdx, compIdx, t_compidx, t_n,
                        ((CvMTStumpTrainParams*)trainParams)->userdata );
            }

            /* 預計算特徵部分,直接尋找最優特徵,也就是傳說中的最優弱分類器 */
            if( sorteddata != NULL )
            {
                if( filter != NULL )    // 需要提取實際訓練樣本
                {
                    switch( sortedtype )
                    {
                        case CV_16SC1:	// 這裏重複度很高,只註釋一個分支,剩下的都一個道理

                            // 從一組特徵(datan個預計算特徵)中尋找最優特徵
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                tk = 0;

                                // 提取實際訓練樣本
                                for( tj = 0; tj < sortedm; tj++ )
                                {
                                    int curidx = (int) ( *((short*) (sorteddata
                                            + ti * sortedcstep + tj * sortedsstep)) );
                                    if( filter[curidx] != 0 )
                                    {
                                        t_idx[tk++] = curidx;
                                    }
                                }

                                // 如果findStumpThreshold_32s返回值爲1, 則更新最優特徵
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        (uchar*) t_idx, sizeof( int ), tk,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32SC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                tk = 0;
                                for( tj = 0; tj < sortedm; tj++ )
                                {
                                    int curidx = (int) ( *((int*) (sorteddata
                                            + ti * sortedcstep + tj * sortedsstep)) );
                                    if( filter[curidx] != 0 )
                                    {
                                        t_idx[tk++] = curidx;
                                    }
                                }
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        (uchar*) t_idx, sizeof( int ), tk,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32FC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                tk = 0;
                                for( tj = 0; tj < sortedm; tj++ )
                                {
                                    int curidx = (int) ( *((float*) (sorteddata
                                            + ti * sortedcstep + tj * sortedsstep)) );
                                    if( filter[curidx] != 0 )
                                    {
                                        t_idx[tk++] = curidx;
                                    }
                                }
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        (uchar*) t_idx, sizeof( int ), tk,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        default:
                            assert( 0 );
                            break;
                    }
                }
                else            // 所有訓練樣本均參與計算
                {
                    switch( sortedtype )
                    {
                        case CV_16SC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                if( findStumpThreshold_16s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        sorteddata + ti * sortedcstep, sortedsstep, sortedm,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32SC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                if( findStumpThreshold_32s[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        sorteddata + ti * sortedcstep, sortedsstep, sortedm,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        case CV_32FC1:
                            for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
                            {
                                if( findStumpThreshold_32f[stumperror]( 
                                        t_data + ti * t_cstep, t_sstep,
                                        wdata, wstep, ydata, ystep,
                                        sorteddata + ti * sortedcstep, sortedsstep, sortedm,
                                        &lerror, &rerror,
                                        &threshold, &left, &right, 
                                        &sumw, &sumwy, &sumwyy ) )
                                {
                                    optcompidx = ti;
                                }
                            }
                            break;
                        default:
                            assert( 0 );
                            break;
                    }
                }
            }

            /* 新特徵部分,要對樣本特徵值進行排序,然後再尋找最優特徵 */
            ti = MAX( t_compidx, MIN( sortedn, t_compidx + t_n ) );
            for( ; ti < t_compidx + t_n; ti++ )
            {
                va.data = t_data + ti * t_cstep;
                va.step = t_sstep;

                // 對樣本特徵值進行排序
                icvSortIndexedValArray_32s( t_idx, l, &va );

                // 繼續尋找最優特徵
                if( findStumpThreshold_32s[stumperror]( 
                        t_data + ti * t_cstep, t_sstep,
                        wdata, wstep, ydata, ystep,
                        (uchar*)t_idx, sizeof( int ), l,
                        &lerror, &rerror,
                        &threshold, &left, &right, 
                        &sumw, &sumwy, &sumwyy ) )
                {
                    optcompidx = ti;
                }
            }
            #ifdef _OPENMP
            #pragma omp critical(c_compidx)
            #endif /* _OPENMP */

            // 更新特徵計算範圍
            {
                t_compidx = compidx;
                compidx += portion;
            }
        }

        #ifdef _OPENMP
        #pragma omp critical(c_beststump)
        #endif /* _OPENMP */

        // 設置最優弱分類器
        {
            if( lerror + rerror < stump->lerror + stump->rerror )
            {
                stump->lerror    = lerror;
                stump->rerror    = rerror;
                stump->compidx   = optcompidx;
                stump->threshold = threshold;
                stump->left      = left;
                stump->right     = right;
            }
        }

        /* free allocated memory */
        if( mat.data.ptr != NULL )
        {
            cvFree( &(mat.data.ptr) );
        }
        if( t_idx != NULL )
        {
            cvFree( &t_idx );
        }
    } /* end of parallel region */

    /* END */

    /* free allocated memory */
    if( filter != NULL )
    {
        cvFree( &filter );
    }
    // 如果設置爲離散型,置信度應爲1或者-1
    if( ((CvMTStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
    {
        stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
        stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
    }

    return (CvClassifier*) stump;
}


其實,我現在一直認爲,尋找弱分類器是一個很easy的過程,根本不需要這麼多行代碼,這麼多局部變量,但是仔細閱讀之後發現,opencv還是很牛的,這段代碼的通用性比較強大,兼顧了並行操作可能性。可以應對多特徵弱分類器,代碼結構也是比較爽快的,尤其是其在條件宏、函數指針方面的應用,令人羨慕異常。今後還要繼續研讀opencv代碼,對編程素養的提高,絕對有很大幫助。

如果有啥問題,還請不吝賜教哦!!!

發佈了42 篇原創文章 · 獲贊 193 · 訪問量 50萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章