Opencv研讀筆記:haartraining程序之莫名其妙的條件宏ICV_DEF_FIND_STUMP_THRESHOLD_SQ解釋~

曾經,糾結過haartraining中條件宏ICV_DEF_FIND_STUMP_THRESHOLD_SQ的使用,主要對它的代碼結構和內容不解,針對這個條件宏,自己專門看了Gentle Adaboost的papers,才得以徹底理解代碼含義,我想也有童鞋同樣對這段代碼比較困惑,所以寫下這篇博客,與大家分享。代碼如下所示:

#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type )                                  \
    ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type,                                     \
        /* calculate error (sum of squares)          */                                  \
        /* err = sum( w * (y - left(rigt)Val)^2 )    */                                  \
        curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl;                \
        currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
    )


(轉載請註明:http://blog.csdn.net/wsj998689aa/article/details/42242565)

對於這段代碼,最好的方式應該事先了解背後的原理,這段代碼可以說是專門爲Gentle Adaboost服務,Gentle Adaboost是四大Adaboost之一,有興趣的童鞋可以去谷歌學術上看看相關文章,我這裏只給出該段代碼的文字算法流程,大家對照着文字看代碼,就能得到更清晰的理解。



裏面的集中度也就是置信度的意思,ICV_DEF_FIND_STUMP_THRESHOLD定義如下(我根據上面的流程,對代碼進行了註釋):

#define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                        \
        uchar* data, size_t datastep, // 樣本HAAR特徵值                                  \
        uchar* wdata, size_t wstep,   // 樣本權值                                        \
        uchar* ydata, size_t ystep,   // 樣本類別                                        \
        uchar* idxdata, size_t idxstep, int num,   // 實際樣本序列                       \
        float* lerror, //閾值左側錯誤率  						 \
        float* rerror, //閾值右側錯誤率  						 \
        float* threshold, float* left, float* right, // 閾值和左右置信度  		 \
        float* sumw, float* sumwy, float* sumwyy )   // 這個不用解釋了                   \
{                                                                                        \
    int found = 0;                                                                       \
    float wyl  = 0.0F;                                                                   \
    float wl   = 0.0F;   <span style="font-family: Arial, Helvetica, sans-serif; font-size: 12px;">// 閾值左側權值和</span><span style="font-size: 12px; font-family: Arial, Helvetica, sans-serif;">                                                               <span style="white-space:pre">				</span> \</span>
    float wyyl = 0.0F;                                                                    \
    float wyr  = 0.0F;                                                                   \
    float wr   = 0.0F;   <span style="font-size: 12px; font-family: Arial, Helvetica, sans-serif;">// 閾值右側權值和</span><span style="font-size: 12px; font-family: Arial, Helvetica, sans-serif;"> </span>
                                                                 <span style="white-space:pre">			</span> \
                                                                                         \
    float curleft  = 0.0F;     // 左分支置信度                                           \
    float curright = 0.0F;     <span style="font-family: Arial, Helvetica, sans-serif; font-size: 12px;">// 右分支置信度</span>
                                                          \
    float* prevval = NULL;     // 中間值,調試用                                         \
    float* curval  = NULL;                                                               \
    float curlerror = 0.0F;    // 閾值左側錯誤率                                         \
    float currerror = 0.0F;    <span style="font-family: Arial, Helvetica, sans-serif; font-size: 12px;">// 閾值右側錯誤率</span>
                                                         \
    float wposl;                                                                         \
    float wposr;                                                                         \
                                                                                         \
    int i = 0;                                                                           \
    int idx = 0;                                                                         \
                                                                                         \
    wposl = wposr = 0.0F;                                                                \
    if( *sumw == FLT_MAX )                                                               \
    {                                                                                    \
        /* calculate sums */                                                             \
        float *y = NULL;                                                                 \
        float *w = NULL;                                                                 \
        float wy = 0.0F;                                                                 \
                                                                                         \
        *sumw   = 0.0F;                                                                  \
        *sumwy  = 0.0F;                                                                  \
        *sumwyy = 0.0F;                                                                  \
        for( i = 0; i < num; i++ )                                                       \
        {                                                                                \
            idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \
            w = (float*) (wdata + idx * wstep);                                          \
            *sumw += *w;                                                                 \
            y = (float*) (ydata + idx * ystep);                                          \
            wy = (*w) * (*y);                                                            \
            *sumwy += wy;                                                                \
            *sumwyy += wy * (*y);                                                        \
        }                                                                                \
    }                                                                                    \
    
    // 遍歷當前特徵值序列的每個元素(閾值),判斷是否存在最優閾值                        \
    for( i = 0; i < num; i++ )                                                           \
    {                                                                                    \
        idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \
        curval = (float*) (data + idx * datastep);                                       \
         /* for debug purpose */                                                         \
        if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \
                                                                                         \
        wyr  = *sumwy - wyl;                                                             \
        wr   = *sumw  - wl;                                                              \
            
	// 計算置信度,也就是集中度                                                      \
        if( wl > 0.0 ) curleft = wyl / wl;                                               \
        else curleft = 0.0F;                                                             \
                                                                                         \
        if( wr > 0.0 ) curright = wyr / wr;                                              \
        else curright = 0.0F;                                                            \
        
	// 此處爲插入代碼段,計算閾值左右error(<span style="font-family: Arial, Helvetica, sans-serif; font-size: 12px;">curlerror, currerror</span>)
        error                                                                            \
                                         
	// 判斷當前curval,found爲1代表找到最優閾值,意味着當前弱分類器最優
        if( curlerror + currerror < (*lerror) + (*rerror) )                              \
        {                                                                                \
            (*lerror) = curlerror;                                                       \
            (*rerror) = currerror;                                                       \
            *threshold = *curval;                                                        \
            if( i > 0 ) {                                                                \
                *threshold = 0.5F * (*threshold + *prevval);                             \
            }                                                                            \
            *left  = curleft;                                                            \
            *right = curright;                                                           \
            found = 1;                                                                   \
        }                                                                                \
            
	// 計算值curval左側的wl,wyl,wyyl
        do                                                                               \
        {                                                                                \
            wl  += *((float*) (wdata + idx * wstep));                                    \
            wyl += (*((float*) (wdata + idx * wstep)))                                   \
                * (*((float*) (ydata + idx * ystep)));                                   \
            wyyl += *((float*) (wdata + idx * wstep))                                    \
                * (*((float*) (ydata + idx * ystep)))                                    \
                * (*((float*) (ydata + idx * ystep)));                                   \
        }                                                                                \
        while( (++i) < num &&                                                            \
            ( *((float*) (data + (idx =                                                  \
                (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \
                == *curval ) );                                                          \
        --i;                                                                             \
        prevval = curval;                                                                \
    } /* for each value */                                                               \
                                                                                         \
    return found;                                                                        \
}


當時覺得很奇怪,內部條件宏ICV_DEF_FIND_STUMP_THRESHOLD括號內明明只有 suffix, type, error 三個參數,怎麼調用傳遞的時候

    ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type,                                     \
        /* calculate error (sum of squares)          */                                  \
        /* err = sum( w * (y - left(rigt)Val)^2 )    */                                  \
        curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl;                \
        currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \

裏面卻是上面這個樣子,前面兩個參數,外加兩句代碼行?

後來弄明白了,我們注意看ICV_DEF_FIND_STUMP_THRESHOLD中的error的相關使用就知道了,在ICV_DEF_FIND_STUMP_THRESHOLD中間,孤零零的一句代碼

                                                                                         \
        error                                                                            \
                                                                                         \

原來error不能說的上是條件宏的一個參數,他就是一個代碼段,調用的時候,相當於直接把代碼段粘貼到上面的位置!

爲啥要這麼用呢,原因就是error代表的代碼段,複用率十分高,所以索性直接寫成這樣的形式,對於函數之間很像的時候,這樣做是一個不錯的選擇,我們遇到這種情況,一般是新創建一個函數,代碼可能沒有條件宏的方式美觀。看來opencv還是很給力的。

類似的宏還有:

ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 16s, short )
ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32s, int )
ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32f, float )
ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 16s, short )
ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32s, int )
ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32f, float )
ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 16s, short )
ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32s, int )
ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32f, float )

此外,sq_##suffix起到了連接字符串的功能,如果suffix是16s,那麼sq_##suffix實際上就是sq_16s,再然後,直接指向相關函數。


發佈了42 篇原創文章 · 獲贊 193 · 訪問量 50萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章