// 基於指令集優化的模型比對程序 (model/feature comparison routines optimized with SIMD instruction sets)

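/**************************************************************************************************/
/*                                                                                                */
/* 文件名稱 (file name): DFR_Compare.h                                                             */
/*   NOTE(review): this file includes "DFR_Compare.h" and holds the implementations, so the name  */
/*   above looks copied from the header; confirm the actual source file name.                     */
/* 文件標識 (file id):   HIK_DFR_LIB_v2                                                            */
/*                                                                                                */
/**************************************************************************************************/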



#include <math.h>
#include <stdio.h>
#include <string.h>
#include "DFR_Compare.h"
//#include "dl_fr_sigmoid.h"
#include "vca_base.h"
#include "DFR_Common.h"
#include "DFR_Error.h"
//avx
#include <immintrin.h>
#include <mkl.h>
#include "mkl_types.h"
#include "mkl_cblas.h"
#include <ipps.h>
#include <ippi.h>
#include <ippcv.h>


/* Sigmoid lookup table: SIG_LUT_LEN samples, where sample i corresponds to the input
 * SIG_START_VAL + i / LUT_STEP (inferred from the index mapping in fr_sigmoid_lut).
 * NOTE(review): the table has no initializer and nothing visible in this file writes it; being
 * static, only this translation unit can fill it - confirm where it is populated. */
static float FR_CONST_SIGMOID_LUT[SIG_LUT_LEN];

/* Approximate sigmoid(val) by linear interpolation in FR_CONST_SIGMOID_LUT.
 * Returns 0.0f below the table range (and for NaN input), 1.0f above it, and the
 * interpolated table value inside it. */
float fr_sigmoid_lut(float val)
{
    /* Map the input onto fractional table-index space. */
    float idx_val = (val - SIG_START_VAL) * LUT_STEP;

    /* Below the table, or NaN: saturate low. Written as !(x >= 0) so NaN never reaches FR_FLOOR. */
    if (!(idx_val >= 0))
    {
        return 0.0f;
    }
    /* Above the table: saturate high. */
    if (idx_val > SIG_LUT_LEN - 1)
    {
        return 1.0f;
    }

    int f_idx = FR_FLOOR(idx_val);

    /* idx_val == SIG_LUT_LEN - 1 exactly: interpolating would read
     * FR_CONST_SIGMOID_LUT[SIG_LUT_LEN], one past the end. The exact answer is the last sample. */
    if (f_idx >= SIG_LUT_LEN - 1)
    {
        return FR_CONST_SIGMOID_LUT[SIG_LUT_LEN - 1];
    }

    float tail = idx_val - f_idx;
    float lo   = FR_CONST_SIGMOID_LUT[f_idx];
    return lo + tail * (FR_CONST_SIGMOID_LUT[f_idx + 1] - lo);
}


/* Argument block for the (currently commented-out) Xbyak-generated dot-product kernel.
 * That kernel loads feat1, feat2 and result from byte offsets 0, 8 and 16 of this struct,
 * so the field order must stay as declared. */
typedef struct _PRODUCT_PARAM_
{
    unsigned char *feat1;   /* first int8 feature vector          */
    unsigned char *feat2;   /* second int8 feature vector         */
    int           *result;  /* destination for the int32 partials */
} PRODUCT_PARAM;

// class ProductJit : public Xbyak::CodeGenerator
// {
// public:
//     ProductJit()
//         :CodeGenerator()
//     {
//     }
//
//     void CreateAvx2()
//     {
//         push(rax);
//         push(rbx);
//         push(rcx);
// 
//         mov(rax, ptr[rdi]);
//         mov(rdx, ptr[rdi + 8]);
//         mov(rcx, ptr[rdi + 16]);
// 
//         // 初始化寄存器
//         vpxor(Xmm(1), Xmm(1), Xmm(1));
//         vpxor(Xmm(2), Xmm(2), Xmm(2));
//         vpxor(Ymm(1), Ymm(1), Ymm(1));
//         vpxor(Ymm(2), Ymm(2), Ymm(2));
//         vpxor(Ymm(3), Ymm(3), Ymm(3));
//         vpxor(Ymm(4), Ymm(4), Ymm(4));
// 
//         for (int i = 0; i < (FEAT_DIM >> 4); i++)
//         {
//             vlddqu(xmm1, ptr[rax + i * 16]);
//             vpmovsxbw(ymm1, xmm1);
//             vlddqu(xmm2, ptr[rdx + i * 16]);
//             vpmovsxbw(ymm2, xmm2);
//             // 16爲乘+
//             vpmaddwd(ymm3, ymm2, ymm1);
//             vpaddd(ymm4, ymm3, ymm4);
//         }
// 
//         // 數據保存
//         vmovupd(yword[rcx], ymm4);
// 
//         pop(rcx);
//         pop(rbx);
//         pop(rax);
//         ret();
//     }
// 
//     void CreateAvx512()
//     {
//         push(rax);
//         push(rbx);
//         push(rcx);
// 
//         mov(rax, ptr[rdi]);
//         mov(rdx, ptr[rdi + 8]);
//         mov(rcx, ptr[rdi + 16]);
// 
//         // 初始化寄存器
//         vpxor(Ymm(1), Ymm(1), Ymm(1));
//         vpxor(Ymm(2), Ymm(2), Ymm(2));
//         vpxord(Zmm(1), Zmm(1), Zmm(1));
//         vpxord(Zmm(2), Zmm(2), Zmm(2));
//         vpxord(Zmm(3), Zmm(3), Zmm(3));
//         vpxord(Zmm(4), Zmm(4), Zmm(4));
// 
//         for (int i = 0; i < (FEAT_DIM >> 5); i++)
//         {
//             vlddqu(ymm1, ptr[rax + i * 32]);
//             vpmovsxbw(zmm1, ymm1);
//             vlddqu(ymm2, ptr[rdx + i * 32]);
//             vpmovsxbw(zmm2, ymm2);
//             // 16爲乘+
//             vpmaddwd(zmm3, zmm1, zmm2);
//             vpaddd(zmm4, zmm3, zmm4);
//         }
// 
//         // 數據保存
//         vmovupd(yword[rcx], zmm4);
// 
// 
// 
//         pop(rcx);
//         pop(rbx);
//         pop(rax);
//         ret();
//     }
// };
// 
// ProductJit     g_ProductJit;






/***************************************************************************************************
* Purpose : 1:1 face comparison (人臉比對1v1接口), AVX2 implementation.
*           Each feature buffer starts with a version tag whose first 3 ints must equal
*           DFR_MODEL_VERSION_TAG. At byte offset FEAT_VER it holds FEAT_DIM values that are
*           treated as signed 8-bit. The similarity is their int32 dot product multiplied by ALPHA.
* Params  : cmp_handle - I  comparison handle (not used by this implementation)
*           fea1       - I  face feature 1
*           fea2       - I  face feature 2
*           sim        - O  similarity
* Returns : HIK_VCA_LIB_S_OK on success. On a version-tag mismatch, CHECK_ERROR is given
*           HIKFR_LIB_NOT_SUPPORT_VERSION (presumably returned by the macro - confirm).
* Note    : NOTE(review): the pointer arguments are not NULL-checked here.
***************************************************************************************************/
HRESULT HIKFR_Compare_1vs1_CPU_v3(void *cmp_handle,
                                  unsigned char *fea1,
                                  unsigned char *fea2,
                                  float *sim)
{
    (void)cmp_handle;

    // Compare the 3-int version tag byte-wise. Reading the byte buffers through an int*
    // is a strict-aliasing violation and may be a misaligned load.
    const size_t tagBytes = 3 * sizeof(int);
    CHECK_ERROR(memcmp(fea1, (const void *)DFR_MODEL_VERSION_TAG, tagBytes) != 0, HIKFR_LIB_NOT_SUPPORT_VERSION);
    CHECK_ERROR(memcmp((const void *)DFR_MODEL_VERSION_TAG, fea2, tagBytes) != 0, HIKFR_LIB_NOT_SUPPORT_VERSION);

    // Keep the scale in a float so the final product is rounded exactly as before.
    const float alpha = ALPHA;

    const char *p1 = (const char *)(&fea1[FEAT_VER]);
    const char *p2 = (const char *)(&fea2[FEAT_VER]);

    // Each step: sign-extend 16 int8 values per feature to int16, multiply pairwise,
    // add adjacent products into 8 int32 lanes, and accumulate.
    __m256i acc = _mm256_setzero_si256();
    for (int i = 0; i < (FEAT_DIM >> 4); i++)
    {
        __m256i a = _mm256_cvtepi8_epi16(_mm_lddqu_si128((const __m128i *)(p1 + 16 * i)));
        __m256i b = _mm256_cvtepi8_epi16(_mm_lddqu_si128((const __m128i *)(p2 + 16 * i)));
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(a, b));
    }

    // Horizontal sum of the 8 int32 lanes.
    int lanes[8];
    _mm256_storeu_si256((__m256i *)lanes, acc);
    int dot = 0;
    for (int i = 0; i < 8; i++)
    {
        dot += lanes[i];
    }

    *sim = (float)dot * alpha;
    return HIK_VCA_LIB_S_OK;
}




/***************************************************************************************************
* Purpose : Batched int16 feature comparison through Intel MKL.
*           Computes result = A * B^T, where A (pInt16_1) is count1 x FEAT_DIM and B (pInt16_2)
*           is count2 x FEAT_DIM, both row-major int16. The output is a count1 x count2 row-major
*           int32 matrix of dot products.
* Params  : cmp_handle - I  comparison handle (unused)
*           pInt16_1   - I  first feature matrix (count1 rows)
*           count1     - I  number of rows in pInt16_1
*           pInt16_2   - I  second feature matrix (count2 rows)
*           count2     - I  number of rows in pInt16_2
*           result     - O  count1 x count2 dot products
*           sim        - O  not written by this implementation
* Returns : HIK_VCA_LIB_S_OK
***************************************************************************************************/
HRESULT HIKFR_Compare_1vs1_CPU_mkl(void *cmp_handle,
                                   short *pInt16_1,
                                   int count1,
                                   short *pInt16_2,
                                   int count2,
                                   int *result,
                                   float *sim)
{
    (void)cmp_handle;
    (void)sim;

    const int   dim           = FEAT_DIM;  // inner dimension; also the row stride of A and B
    const int   fixedOffsetC  = 0;         // CblasFixOffset: this single value is added to every C element
    const float productScale  = 1.0f;      // alpha
    const float existingScale = 0.0f;      // beta: overwrite result instead of accumulating

    // A is used as stored (NoTrans). B holds one feature per row and is transposed (Trans).
    // A and B have zero offsets; C's row stride is count2.
    cblas_gemm_s16s16s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset,
                         count1, count2, dim,
                         productScale,
                         pInt16_1, dim, 0,
                         pInt16_2, dim, 0,
                         existingScale,
                         result, count2, &fixedOffsetC);

    return HIK_VCA_LIB_S_OK;
}

/* Feature length in elements (int8 in the gallery, int16 in the widened queries). Several
 * avx2_gemm_* kernels below use this macro, so it must be defined before them. */
#define VEC_LENGH 512

/***************************************************************************************************
* Purpose : AVX2 dot products of 8 int16 query rows against one int8 gallery record.
*           The record is sign-extended to int16, then for k = 0..7:
*             outPut[k * dbSize + iRecord] = sum over VEC_LENGH elements of query_k[i] * record[i]
* Params  : pInFeatureBuffer0..7 - I  query rows of VEC_LENGH int16 values. They must be 32-byte
*                                     aligned because they are read with _mm256_load_si256.
*           inputB   - I  gallery base; record iRecord starts at inputB + iRecord * dbStep
*           outPut   - O  base of an 8 x dbSize row-major int32 score matrix
*           iRecord  - I  gallery record index (also the output column)
*           dbStep   - I  byte stride between gallery records
*           dbSize   - I  output row stride (number of gallery records)
***************************************************************************************************/
void inline block8_1(__m256i * pInFeatureBuffer0,
                    __m256i * pInFeatureBuffer1,
                    __m256i * pInFeatureBuffer2,
                    __m256i * pInFeatureBuffer3,
                    __m256i * pInFeatureBuffer4,
                    __m256i * pInFeatureBuffer5,
                    __m256i * pInFeatureBuffer6,
                    __m256i * pInFeatureBuffer7,
                    char *inputB, int *outPut, int iRecord,int dbStep,int dbSize)
{
    // size_t arithmetic: iRecord * dbStep in int overflows for galleries larger than 2 GiB.
    const __m128i *pFeatureDataBase = (const __m128i *)(inputB + (size_t)iRecord * (size_t)dbStep);

    // Explicit zeroing. XOR-ing an uninitialized register with itself reads an indeterminate value.
    __m256i sum0 = _mm256_setzero_si256();
    __m256i sum1 = _mm256_setzero_si256();
    __m256i sum2 = _mm256_setzero_si256();
    __m256i sum3 = _mm256_setzero_si256();
    __m256i sum4 = _mm256_setzero_si256();
    __m256i sum5 = _mm256_setzero_si256();
    __m256i sum6 = _mm256_setzero_si256();
    __m256i sum7 = _mm256_setzero_si256();

    // Eight independent accumulators kept in registers. Each gallery chunk is widened once
    // and reused for all 8 queries.
    for (int i = 0; i < VEC_LENGH / 16; i++)
    {
        const __m256i db = _mm256_cvtepi8_epi16(_mm_lddqu_si128(pFeatureDataBase + i));

        sum0 = _mm256_add_epi32(sum0, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer0 + i), db));
        sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer1 + i), db));
        sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer2 + i), db));
        sum3 = _mm256_add_epi32(sum3, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer3 + i), db));
        sum4 = _mm256_add_epi32(sum4, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer4 + i), db));
        sum5 = _mm256_add_epi32(sum5, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer5 + i), db));
        sum6 = _mm256_add_epi32(sum6, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer6 + i), db));
        sum7 = _mm256_add_epi32(sum7, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer7 + i), db));
    }

    // Horizontal sum of each accumulator's 8 int32 lanes, stored as int (no float punning).
    const __m256i sums[8] = { sum0, sum1, sum2, sum3, sum4, sum5, sum6, sum7 };
    for (int k = 0; k < 8; k++)
    {
        __m128i s = _mm_add_epi32(_mm256_castsi256_si128(sums[k]), _mm256_extracti128_si256(sums[k], 1));
        s = _mm_hadd_epi32(s, s);
        s = _mm_hadd_epi32(s, s);
        outPut[(size_t)k * (size_t)dbSize + (size_t)iRecord] = _mm_cvtsi128_si32(s);
    }
}


void avx2_gemm_s8s8s32_16batch(char *inputA, char *inputB, int *outPut, int dbStep, int dbSize) 
{
		Ipp16s * inFeatureBuffer;  //, *inFeatureBuffer1;
	
	__m256i * pInFeatureBuffer0,* pInFeatureBuffer1,* pInFeatureBuffer2,* pInFeatureBuffer3,
	        * pInFeatureBuffer4,* pInFeatureBuffer5, * pInFeatureBuffer6, * pInFeatureBuffer7;
	
	inFeatureBuffer = ippsMalloc_16s(VEC_LENGH*16);
	ippsConvert_8s16s((Ipp8s*)inputA, inFeatureBuffer, VEC_LENGH*16);
	
	int iRecord = 0; 
	while (iRecord < dbSize) {
		
	 pInFeatureBuffer0 = (__m256i *) inFeatureBuffer;
	 pInFeatureBuffer1 = (__m256i *) (inFeatureBuffer+VEC_LENGH);
	 pInFeatureBuffer2 = (__m256i *) (inFeatureBuffer+2*VEC_LENGH);
	 pInFeatureBuffer3 = (__m256i *) (inFeatureBuffer+3*VEC_LENGH);
	 pInFeatureBuffer4 = (__m256i *) (inFeatureBuffer+4*VEC_LENGH);
	 pInFeatureBuffer5= (__m256i *)  (inFeatureBuffer+5*VEC_LENGH);
	 pInFeatureBuffer6 = (__m256i *) (inFeatureBuffer+6*VEC_LENGH);
	 pInFeatureBuffer7 = (__m256i *) (inFeatureBuffer+7*VEC_LENGH);
	 
	 
	 block8_1( pInFeatureBuffer0, pInFeatureBuffer1,pInFeatureBuffer2,  pInFeatureBuffer3,
               pInFeatureBuffer4, pInFeatureBuffer5, pInFeatureBuffer6, pInFeatureBuffer7,
               inputB, outPut, iRecord,dbStep,dbSize);

     pInFeatureBuffer0 = (__m256i *) (inFeatureBuffer+8*VEC_LENGH);
	 pInFeatureBuffer1 = (__m256i *) (inFeatureBuffer+9*VEC_LENGH);
	 pInFeatureBuffer2 = (__m256i *) (inFeatureBuffer+10*VEC_LENGH);
	 pInFeatureBuffer3 = (__m256i *) (inFeatureBuffer+11*VEC_LENGH);
	 pInFeatureBuffer4 = (__m256i *) (inFeatureBuffer+12*VEC_LENGH);
	 pInFeatureBuffer5= (__m256i *)  (inFeatureBuffer+13*VEC_LENGH);
	 pInFeatureBuffer6 = (__m256i *) (inFeatureBuffer+14*VEC_LENGH);
	 pInFeatureBuffer7 = (__m256i *) (inFeatureBuffer+15*VEC_LENGH);

	 block8_1( pInFeatureBuffer0, pInFeatureBuffer1,pInFeatureBuffer2,  pInFeatureBuffer3,
               pInFeatureBuffer4, pInFeatureBuffer5, pInFeatureBuffer6, pInFeatureBuffer7,
               inputB, outPut + 8*dbSize, iRecord,dbStep,dbSize);
			   
	iRecord++;
	 
	}
}


void avx2_gemm_s8s8s32_nbatch(char *inputA, char *inputB, int *outPut, int dbStep, int dbSize, int nInput) 
{

	#define VEC_LENGH 512

	Ipp16s * inFeatureBuffer;  //, *inFeatureBuffer1;
	
	__m256i * pInFeatureBuffer0,* pInFeatureBuffer1,* pInFeatureBuffer2,* pInFeatureBuffer3,
	        * pInFeatureBuffer4,* pInFeatureBuffer5, * pInFeatureBuffer6, * pInFeatureBuffer7;
	
	inFeatureBuffer = ippsMalloc_16s(VEC_LENGH*nInput);
	ippsConvert_8s16s((Ipp8s*)inputA, inFeatureBuffer, VEC_LENGH*nInput);
	
	int iRecord = 0; 
	while (iRecord < dbSize) {
		
	   for(int j=0;j< nInput; j+=8){
	    pInFeatureBuffer0 = (__m256i *) (inFeatureBuffer+j*VEC_LENGH);
	    pInFeatureBuffer1 = (__m256i *) (inFeatureBuffer+j*VEC_LENGH+VEC_LENGH);
	    pInFeatureBuffer2 = (__m256i *) (inFeatureBuffer+j*VEC_LENGH+2*VEC_LENGH);
	    pInFeatureBuffer3 = (__m256i *) (inFeatureBuffer+j*VEC_LENGH+3*VEC_LENGH);
	    pInFeatureBuffer4 = (__m256i *) (inFeatureBuffer+j*VEC_LENGH+4*VEC_LENGH);
	    pInFeatureBuffer5= (__m256i *)  (inFeatureBuffer+j*VEC_LENGH+5*VEC_LENGH);
	    pInFeatureBuffer6 = (__m256i *) (inFeatureBuffer+j*VEC_LENGH+6*VEC_LENGH);
	    pInFeatureBuffer7 = (__m256i *) (inFeatureBuffer+j*VEC_LENGH+7*VEC_LENGH);
	 
	   
	    block8_1( pInFeatureBuffer0, pInFeatureBuffer1,pInFeatureBuffer2,  pInFeatureBuffer3,
               pInFeatureBuffer4, pInFeatureBuffer5, pInFeatureBuffer6, pInFeatureBuffer7,
               inputB, outPut, iRecord,dbStep,dbSize); 

	   }	   
	iRecord++;
	 
	}
}



/***************************************************************************************************
* Purpose : int8 x int8 -> int32 dot products of 4 query features against dbSize gallery records.
* Params  : inputA - I  4 query features of VEC_LENGH int8 values each, stored back to back
*           inputB - I  gallery; record r starts at inputB + r * dbStep
*           outPut - O  4 x dbSize row-major: outPut[q * dbSize + r] = <query q, record r>
*           dbStep - I  byte stride between gallery records
*           dbSize - I  number of gallery records
* Note    : returns without writing outPut if the int16 staging buffer cannot be allocated.
***************************************************************************************************/
void avx2_gemm_s8s8s32_4batch(char *inputA, char *inputB, int *outPut, int dbStep, int dbSize)
{
#define VEC_LENGH 512

    // Widen the queries to int16 once. ippsMalloc_* returns 64-byte-aligned memory (per the IPP
    // documentation), as required by the aligned loads below.
    Ipp16s *inFeatureBuffer = ippsMalloc_16s(VEC_LENGH * 4);
    if (inFeatureBuffer == NULL)
    {
        return;
    }
    ippsConvert_8s16s((Ipp8s *)inputA, inFeatureBuffer, VEC_LENGH * 4);

    const __m256i *pIn0 = (const __m256i *)(inFeatureBuffer);
    const __m256i *pIn1 = (const __m256i *)(inFeatureBuffer + VEC_LENGH);
    const __m256i *pIn2 = (const __m256i *)(inFeatureBuffer + 2 * VEC_LENGH);
    const __m256i *pIn3 = (const __m256i *)(inFeatureBuffer + 3 * VEC_LENGH);

    for (int iRecord = 0; iRecord < dbSize; iRecord++)
    {
        // size_t arithmetic: iRecord * dbStep in int overflows for galleries larger than 2 GiB.
        const __m128i *pFeatureDataBase = (const __m128i *)(inputB + (size_t)iRecord * (size_t)dbStep);

        // Explicit zeroing. XOR-ing an uninitialized register with itself reads an indeterminate value.
        __m256i sum0 = _mm256_setzero_si256();
        __m256i sum1 = _mm256_setzero_si256();
        __m256i sum2 = _mm256_setzero_si256();
        __m256i sum3 = _mm256_setzero_si256();

        for (int i = 0; i < VEC_LENGH / 16; i++)
        {
            const __m256i db = _mm256_cvtepi8_epi16(_mm_lddqu_si128(pFeatureDataBase + i));

            sum0 = _mm256_add_epi32(sum0, _mm256_madd_epi16(_mm256_load_si256(pIn0 + i), db));
            sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(_mm256_load_si256(pIn1 + i), db));
            sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(_mm256_load_si256(pIn2 + i), db));
            sum3 = _mm256_add_epi32(sum3, _mm256_madd_epi16(_mm256_load_si256(pIn3 + i), db));
        }

        // Horizontal sum of each accumulator, stored as int (no float punning).
        const __m256i sums[4] = { sum0, sum1, sum2, sum3 };
        for (int q = 0; q < 4; q++)
        {
            __m128i s = _mm_add_epi32(_mm256_castsi256_si128(sums[q]), _mm256_extracti128_si256(sums[q], 1));
            s = _mm_hadd_epi32(s, s);
            s = _mm_hadd_epi32(s, s);
            outPut[(size_t)q * (size_t)dbSize + (size_t)iRecord] = _mm_cvtsi128_si32(s);
        }
    }

    ippsFree(inFeatureBuffer);
}



/***************************************************************************************************
* Purpose : int8 x int8 -> int32 dot products of 2 query features against dbSize gallery records.
* Params  : inputA - I  2 query features of VEC_LENGH int8 values each, stored back to back
*           inputB - I  gallery; record r starts at inputB + r * dbStep
*           outPut - O  2 x dbSize row-major: outPut[q * dbSize + r] = <query q, record r>
*           dbStep - I  byte stride between gallery records
*           dbSize - I  number of gallery records
* Note    : returns without writing outPut if the int16 staging buffer cannot be allocated.
***************************************************************************************************/
void avx2_gemm_s8s8s32_2batch(char *inputA, char *inputB, int *outPut, int dbStep, int dbSize)
{
#define VEC_LENGH 512

    // Widen the queries to int16 once. ippsMalloc_* returns 64-byte-aligned memory (per the IPP
    // documentation), as required by the aligned loads below.
    Ipp16s *inFeatureBuffer = ippsMalloc_16s(VEC_LENGH * 2);
    if (inFeatureBuffer == NULL)
    {
        return;
    }
    ippsConvert_8s16s((Ipp8s *)inputA, inFeatureBuffer, VEC_LENGH * 2);

    const __m256i *pIn0 = (const __m256i *)(inFeatureBuffer);
    const __m256i *pIn1 = (const __m256i *)(inFeatureBuffer + VEC_LENGH);

    for (int iRecord = 0; iRecord < dbSize; iRecord++)
    {
        // size_t arithmetic: iRecord * dbStep in int overflows for galleries larger than 2 GiB.
        const __m128i *pFeatureDataBase = (const __m128i *)(inputB + (size_t)iRecord * (size_t)dbStep);

        // Explicit zeroing. XOR-ing an uninitialized register with itself reads an indeterminate value.
        __m256i sum0 = _mm256_setzero_si256();
        __m256i sum1 = _mm256_setzero_si256();

        for (int i = 0; i < VEC_LENGH / 16; i++)
        {
            const __m256i db = _mm256_cvtepi8_epi16(_mm_lddqu_si128(pFeatureDataBase + i));

            sum0 = _mm256_add_epi32(sum0, _mm256_madd_epi16(_mm256_load_si256(pIn0 + i), db));
            sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(_mm256_load_si256(pIn1 + i), db));
        }

        // Horizontal sum of each accumulator, stored as int (no float punning).
        const __m256i sums[2] = { sum0, sum1 };
        for (int q = 0; q < 2; q++)
        {
            __m128i s = _mm_add_epi32(_mm256_castsi256_si128(sums[q]), _mm256_extracti128_si256(sums[q], 1));
            s = _mm_hadd_epi32(s, s);
            s = _mm_hadd_epi32(s, s);
            outPut[(size_t)q * (size_t)dbSize + (size_t)iRecord] = _mm_cvtsi128_si32(s);
        }
    }

    ippsFree(inFeatureBuffer);
}


/***************************************************************************************************
* Purpose : int8 x int8 -> int32 dot products of one query feature against dbSize gallery records.
* Params  : inputA - I  one query feature of VEC_LENGH int8 values
*           inputB - I  gallery; record r starts at inputB + r * dbStep
*           outPut - O  dbSize scores: outPut[r] = <query, record r>
*           dbStep - I  byte stride between gallery records
*           dbSize - I  number of gallery records
* Note    : returns without writing outPut if the int16 staging buffer cannot be allocated.
***************************************************************************************************/
void avx2_gemm_s8s8s32_1batch(char *inputA, char *inputB, int *outPut, int dbStep, int dbSize)
{
#define VEC_LENGH 512

    // Only one query row is used, so convert exactly VEC_LENGH values. Converting
    // VEC_LENGH * 2 read past a single-feature inputA.
    Ipp16s *inFeatureBuffer = ippsMalloc_16s(VEC_LENGH);
    if (inFeatureBuffer == NULL)
    {
        return;
    }
    ippsConvert_8s16s((Ipp8s *)inputA, inFeatureBuffer, VEC_LENGH);

    // ippsMalloc_* returns 64-byte-aligned memory (per the IPP documentation), as required by
    // _mm256_load_si256.
    const __m256i *pInFeatureBuffer = (const __m256i *)inFeatureBuffer;

    for (int iRecord = 0; iRecord < dbSize; iRecord++)
    {
        // size_t arithmetic: iRecord * dbStep in int overflows for galleries larger than 2 GiB.
        const __m128i *pFeatureDataBase = (const __m128i *)(inputB + (size_t)iRecord * (size_t)dbStep);

        // Explicit zeroing. XOR-ing an uninitialized register with itself reads an indeterminate value.
        __m256i sum = _mm256_setzero_si256();
        for (int i = 0; i < VEC_LENGH / 16; i++)
        {
            const __m256i db = _mm256_cvtepi8_epi16(_mm_lddqu_si128(pFeatureDataBase + i));
            sum = _mm256_add_epi32(sum, _mm256_madd_epi16(_mm256_load_si256(pInFeatureBuffer + i), db));
        }

        // Horizontal sum of the 8 int32 lanes, stored as int (no float punning).
        __m128i s = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
        s = _mm_hadd_epi32(s, s);
        s = _mm_hadd_epi32(s, s);
        outPut[iRecord] = _mm_cvtsi128_si32(s);
    }

    ippsFree(inFeatureBuffer);
}

typedef struct _COMPARE_PARAM_
{
    int     nNum;
    __m256i tmp256[MAX_FIX_NUM][32];
}COMPARE_PARAM;

COMPARE_PARAM stParam[58];

/***************************************************************************************************
* Purpose : Allocate a COMPARE_PARAM and return it through *handle.
* Params  : handle - O  receives the new object
* Returns : 0 on success; -1 if handle is NULL or the allocation fails
* Note    : HIKFR_COmpare_Set and HIKFR_Compare_nvs1_CPU_v3 in this file operate on the global
*           stParam slots and do not read this object.
*           NOTE(review): COMPARE_PARAM contains __m256i members (32-byte alignment), but malloc
*           only guarantees fundamental alignment. Aligned vector stores into this object could
*           fault, so switch to an aligned allocator with a matching release before using it.
***************************************************************************************************/
HRESULT HIKFR_COmpare_Create(void **handle)
{
    if (handle == NULL)
    {
        return -1;
    }

    COMPARE_PARAM *pstParam = (COMPARE_PARAM *)malloc(sizeof(COMPARE_PARAM));
    if (pstParam == NULL)
    {
        printf("3333333333\n");
        return -1;
    }

    // malloc does not zero memory; start with an empty feature set.
    pstParam->nNum = 0;

    *handle = (void *)pstParam;

    // %p is the conversion for pointers; passing a pointer to %x is undefined behavior.
    printf("%p\n", (void *)pstParam);
    return 0;
}

/***************************************************************************************************
* Purpose : Load nNum gallery features into global slot stParam[index], widened to int16.
* Params  : handle - I  unused (the features are stored in the global slot, not in the handle)
*           index  - I  slot number, 0 <= index < number of stParam slots
*           nNum   - I  number of features, 0 <= nNum <= rows available in a slot
*           fea    - I  nNum feature buffers. FEAT_DIM int8 values start at byte FEAT_VER of each.
* Returns : 0 on success; -1 on an invalid index or count, a NULL feature pointer, or a FEAT_DIM
*           larger than a slot row can hold. The slot is left untouched on failure.
***************************************************************************************************/
HRESULT HIKFR_COmpare_Set(void *handle, int index, int nNum, unsigned char **fea)
{
    (void)handle;

    // Capacities come from the declarations themselves, so they cannot drift from the layout.
    const int nSlots   = (int)(sizeof(stParam) / sizeof(stParam[0]));
    const int nMaxRows = (int)(sizeof(stParam[0].tmp256) / sizeof(stParam[0].tmp256[0]));
    const int nRowVecs = (int)(sizeof(stParam[0].tmp256[0]) / sizeof(stParam[0].tmp256[0][0]));
    const int nChunks  = FEAT_DIM >> 4;   // 16 int8 values per AVX2 vector

    if (index < 0 || index >= nSlots || nNum < 0 || nNum > nMaxRows || nChunks > nRowVecs)
    {
        return -1;
    }
    // Validate every pointer before writing, so a bad call cannot leave a half-updated slot.
    for (int j = 0; j < nNum; j++)
    {
        if (fea == NULL || fea[j] == NULL)
        {
            return -1;
        }
    }

    COMPARE_PARAM *pstParam = &stParam[index];

    for (int j = 0; j < nNum; j++)
    {
        const char *src = (const char *)(&fea[j][FEAT_VER]);
        for (int i = 0; i < nChunks; i++)
        {
            // Sign-extend 16 int8 values to int16.
            pstParam->tmp256[j][i] = _mm256_cvtepi8_epi16(_mm_lddqu_si128((const __m128i *)(src + 16 * i)));
        }
    }

    pstParam->nNum = nNum;
    return 0;
}

/***************************************************************************************************
* Purpose : n:1 comparison. Compares one probe feature against every feature cached in
*           stParam[index] by HIKFR_COmpare_Set (AVX2).
* Params  : dfr_handle - I  unused
*           index      - I  gallery slot number
*           fea2       - I  probe feature. The first 3 ints must equal DFR_MODEL_VERSION_TAG.
*                           FEAT_DIM int8 values start at byte FEAT_VER.
*           sim        - O  stParam[index].nNum similarities: int32 dot product * ALPHA
* Returns : HIK_VCA_LIB_S_OK on success; -1 for an out-of-range index. On a version mismatch,
*           CHECK_ERROR is given HIKFR_LIB_NOT_SUPPORT_VERSION (presumably returned by the
*           macro - confirm).
***************************************************************************************************/
HRESULT HIKFR_Compare_nvs1_CPU_v3(void *dfr_handle,
                                  int index,
                                  unsigned char *fea2,
                                  float *sim)
{
    (void)dfr_handle;

    // Byte-wise version-tag check; avoids reading the byte buffer through an int*.
    CHECK_ERROR(memcmp((const void *)DFR_MODEL_VERSION_TAG, fea2, 3 * sizeof(int)) != 0, HIKFR_LIB_NOT_SUPPORT_VERSION);

    if (index < 0 || index >= (int)(sizeof(stParam) / sizeof(stParam[0])))
    {
        return -1;
    }
    const COMPARE_PARAM *pstParam = &stParam[index];

    // Keep the scale in a float so the final product is rounded exactly as before.
    const float alpha = ALPHA;

    // The probe is the same for every gallery row, so widen it to int16 once, not once per row.
    __m256i probe16[FEAT_DIM >> 4];
    const char *src = (const char *)(&fea2[FEAT_VER]);
    for (int i = 0; i < (FEAT_DIM >> 4); i++)
    {
        probe16[i] = _mm256_cvtepi8_epi16(_mm_lddqu_si128((const __m128i *)(src + 16 * i)));
    }

    for (int j = 0; j < pstParam->nNum; j++)
    {
        __m256i acc = _mm256_setzero_si256();
        for (int i = 0; i < (FEAT_DIM >> 4); i++)
        {
            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(pstParam->tmp256[j][i], probe16[i]));
        }

        // Horizontal sum of the 8 int32 lanes.
        int lanes[8];
        _mm256_storeu_si256((__m256i *)lanes, acc);
        int dot = 0;
        for (int k = 0; k < 8; k++)
        {
            dot += lanes[k];
        }

        sim[j] = (float)dot * alpha;
    }

    return HIK_VCA_LIB_S_OK;
}

//topK
/** @fn       UpToDown
 *  @brief    構建最小堆
 *  @param    nRootPos    [in]  - 根節點下標
 *  @param    nResultCnt  [in]  - 排序結果大小
 *  @param    pResult     [out] - 排序結果數組
 *  @return   int
 */
void UpToDown(int nRootPos, int nResultCnt, void* pResult)
{
    int nPos;
    HEAP_INFO  stTmpInfo;
    HEAP_INFO* pstResult = (HEAP_INFO*)pResult;

    // 左孩子(存在的話)
    int nLeft = 2 * nRootPos;
    // 右孩子(存在的話)
    int nRight = nLeft + 1;

    // 無孩子節點,調整結束
    if (nLeft >= nResultCnt)
    {
        return;
    }
    else
    {
        // 只有左孩子
        if (nRight >= nResultCnt)
        {
            nPos = nLeft;
        }
        else
        {
            nPos = pstResult[nLeft].fSimilarity > pstResult[nRight].fSimilarity ? nRight : nLeft;
        }

        // pos保存在子孩子中,數值較小者的位置
        if (pstResult[nRootPos].fSimilarity > pstResult[nPos].fSimilarity)   //compare with lower children, lower up
        {
            stTmpInfo           = pstResult[nRootPos];
            pstResult[nRootPos] = pstResult[nPos];
            pstResult[nPos]     = stTmpInfo;

            UpToDown(nPos, nResultCnt, (void *)pstResult);
        }
    }
}

/** @fn       HeapSortResult
 *  @brief    最小堆排序
 *  @param    pSimilarityLocal       [in]   -  比對後相似度結果(CPU內存)
 *  @param    pIndexFlg              [in]   -  模型有效標識
 *  @param    pHeapInfo              [in]   -  最小堆數組
 *  @param    nHeapSize              [in]   -  數組大小
 *  @return   錯誤碼
 *  @note     
 */
int HeapSortResult(unsigned int startIdx, unsigned int len, float* pSimilarityLocal, unsigned char* pIndexFlg, HEAP_INFO* pHeapInfo, unsigned int& nHeapSize)
{
    if (NULL == pSimilarityLocal || NULL == pHeapInfo)
    {
        return -1;
    }

    // TOP_K 排序
    //memset(pHeapInfo, 0, nHeapSize * sizeof(HEAP_INFO));
//pHeapInfo[0] for flag, fill 1,add 1; pHeapInfo[1]-pHeapInfo[1000] to store, init to 0;
	int nTran = 0;  
	if(pHeapInfo[0].nIndex < nHeapSize-1) //fill + potential heapify+ potential update
	{
		for( ; nTran < len; nTran++)
		{
			if (1 == pIndexFlg[nTran])
            {
				pHeapInfo[0].nIndex++;
                pHeapInfo[pHeapInfo[0].nIndex].fSimilarity = pSimilarityLocal[nTran];
                pHeapInfo[pHeapInfo[0].nIndex].nIndex      = startIdx+nTran;
				if(pHeapInfo[0].nIndex >= nHeapSize -1)
				{
					 // heapify 1000
                    for (unsigned int i = nHeapSize / 2; i >= 1; i--)           //build heap of size 1000
   				    {
      			        UpToDown(i, nHeapSize, (void*)pHeapInfo);
   					}
					break;
				}
            }
			
		}
		//update the rest
		 __m512 vs    = _mm512_set1_ps(pHeapInfo[1].fSimilarity);
		for ( ; nTran < len-16; nTran+=16 )
    	{
			
			__m512 v0    = _mm512_loadu_ps(&pSimilarityLocal[nTran]);   //Latency:1, Throughput:0.5
	
			__mmask16 res0  = _mm512_cmp_ps_mask(v0, vs, _CMP_GT_OQ);          //Latency:3, Throughput:1
        
			//int mask0 = _mm256_movemask_ps(res0);
		
			if(res0!=0)
			{
		  	for(int i = 0; i<16; i++)
		  	{
				if (pSimilarityLocal[nTran+i] > pHeapInfo[1].fSimilarity)   //max k  [0] is min of 1000max
            	{
                	pHeapInfo[1].fSimilarity = pSimilarityLocal[nTran+i];
                	pHeapInfo[1].nIndex      = startIdx+nTran+i;
                	UpToDown(1, nHeapSize, (void*)pHeapInfo);
            	}
		  	}
			vs    = _mm512_set1_ps(pHeapInfo[1].fSimilarity);
			}
			
		}
		for( ; nTran < len; nTran++)
		{
      		if (pSimilarityLocal[nTran] > pHeapInfo[1].fSimilarity)   //max k  [1] is min of 1000max
        	{
            	pHeapInfo[1].fSimilarity = pSimilarityLocal[nTran];
            	pHeapInfo[1].nIndex      = startIdx+nTran;
            	UpToDown(1, nHeapSize, (void*)pHeapInfo);
        	}
        	else
       	 	{
            	continue;
        	}
			
		}
	}
	else //update
	{
		
		__m512 vs    = _mm512_set1_ps(pHeapInfo[1].fSimilarity);
		
		for ( ; nTran < len-16; nTran+=16)
    	{
			__m512 v0    = _mm512_loadu_ps(&pSimilarityLocal[nTran]);   //Latency:1, Throughput:0.5
	
			//__m256 vs    = _mm256_set1_ps(pHeapInfo[1].fSimilarity);
			__mmask16 res0  = _mm512_cmp_ps_mask(v0, vs, _CMP_GT_OQ);          //Latency:3, Throughput:1
        
			//int mask0 = _mm256_movemask_ps(res0);
		
			if(res0!=0)
			{
		  	for(int i = 0; i<16; i++)
		  	{
				if (pSimilarityLocal[nTran+i] > pHeapInfo[1].fSimilarity)   //max k  [0] is min of 1000max
            	{
                	pHeapInfo[1].fSimilarity = pSimilarityLocal[nTran+i];
                	pHeapInfo[1].nIndex      = startIdx+nTran+i;
                	UpToDown(1, nHeapSize, (void*)pHeapInfo);
            	}
		  	}
			vs    = _mm512_set1_ps(pHeapInfo[1].fSimilarity);
			}
		}
			/**/
		
		for( ; nTran < len; nTran++)
		{
			if (pSimilarityLocal[nTran] > pHeapInfo[1].fSimilarity)   //max k  [1] is min of 1000max
        	{
            	pHeapInfo[1].fSimilarity = pSimilarityLocal[nTran];
            	pHeapInfo[1].nIndex      = startIdx+nTran;
            	UpToDown(1, nHeapSize, (void*)pHeapInfo);
        	}
        	else
       		{
            	continue;
        	}
	    }
	}
	return 0;
}

 

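/*
 * 發表評論
 * 所有評論
 * 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
 * 相關文章
 *
 * NOTE(review): the lines above are residue from the web page this source was copied from
 * ("Post a comment / All comments / No one has commented yet - enter a comment above and
 * click publish / Related articles"). They are not part of the program.
 */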