【darknet學習筆記】訓練和預測圖片分類,加載數據比較

修改darknet源代碼,使其能夠直接訓練二進制圖像數據

1.訓練加載數據load_data()和預測加載數據load_data_in_thread()比較

訓練分類圖像函數train_classifier()中加載圖片數據,使用load_data()

darknet.c文件中load_data()代碼:

/*
 * Start a background thread that runs load_threads() on a heap copy of
 * `args` and return the handle of that thread.
 *
 * The copy is handed over to the new thread; load_threads() frees the
 * pointer it receives, so the caller must not free anything here.
 * Failures are reported through error().
 */
pthread_t load_data(load_args args)
{
    pthread_t thread;
    struct load_args* ptr = (load_args*)calloc(1, sizeof(*ptr));
    /* Do not write through a failed allocation.
     * NOTE(review): error() is presumed not to return -- confirm. */
    if(!ptr) error("Calloc failed");
    *ptr = args;
    if(pthread_create(&thread, 0, load_threads, ptr)){
        /* The thread never started, so nobody else will release the copy. */
        free(ptr);
        error("Thread creation failed");
    }
    return thread;
}

預測圖像分類函數test_classifier()中,加載圖片數據使用load_data_in_thread()

load_data_in_thread()也在darknet.c中實現

/*
 * Start a background thread that runs load_thread() on a heap copy of
 * `args` and return the handle of that thread.
 *
 * The copy is handed over to the new thread; load_thread() frees the
 * pointer it receives, so the caller must not free anything here.
 * Failures are reported through error().
 */
pthread_t load_data_in_thread(load_args args)
{
    pthread_t thread;
    struct load_args* ptr = (load_args*)calloc(1, sizeof(*ptr));
    /* Do not write through a failed allocation.
     * NOTE(review): error() is presumed not to return -- confirm. */
    if(!ptr) error("Calloc failed");
    *ptr = args;
    if(pthread_create(&thread, 0, load_thread, ptr)){
        /* The thread never started, so nobody else will release the copy. */
        free(ptr);
        error("Thread creation failed");
    }
    return thread;
}

2.訓練train加載數據線程load_threads()和預測test加載數據線程load_thread()

load_thread():

void *load_thread(void *ptr)
{
    //srand(time(0));
    //printf("Loading data: %d\n", random_gen());
    load_args a = *(struct load_args*)ptr;
    if(a.exposure == 0) a.exposure = 1;
    if(a.saturation == 0) a.saturation = 1;
    if(a.aspect == 0) a.aspect = 1;

    if (a.type == OLD_CLASSIFICATION_DATA){
        *a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
    } else if (a.type == OLD_CLASSIFICATION_MEMORY){
		//*a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
		*a.d = load_data_memory(a.imgdata, a.n, a.classes, a.w, a.h, a.out_w, a.out_h);
    }else if (a.type == CLASSIFICATION_DATA){
       // *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.flip, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
    } else if (a.type == SUPER_DATA){
       // *a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
    } else if (a.type == WRITING_DATA){
       // *a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
    } else if (a.type == REGION_DATA){
       // *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
    } else if (a.type == DETECTION_DATA){
       // *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.blur, a.mixup, a.jitter,
        //    a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.letter_box, a.show_imgs);
    } else if (a.type == SWAG_DATA){
       // *a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
    } else if (a.type == COMPARE_DATA){
      //  *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
    } else if (a.type == IMAGE_DATA){
    //    *(a.im) = load_image(a.path, 0, 0, a.c);
        *(a.resized) = resize_image(*(a.im), a.w, a.h);
    }else if (a.type == LETTERBOX_DATA) {
      //  *(a.im) = load_image(a.path, 0, 0, a.c);
       // *(a.resized) = letterbox_image(*(a.im), a.w, a.h);
    } else if (a.type == TAG_DATA){
       // *a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.flip, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
    }
    free(ptr);
    return 0;
}

此代碼是我修改過的預測動態庫代碼,註釋部分在源代碼中都有用,只是在我的動態庫中不需要。

預測時新加了內存圖片類型:

 } else if (a.type == OLD_CLASSIFICATION_MEMORY){
		//*a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
		*a.d = load_data_memory(a.imgdata, a.n, a.classes, a.w, a.h, a.out_w, a.out_h);

load_threads():

void *load_threads(void *ptr)
{
    //srand(time(0));
    int i;
    load_args args = *(load_args *)ptr;
    if (args.threads == 0) args.threads = 1;
    data *out = args.d;
    int total = args.n;
    free(ptr);
    data* buffers = (data*)calloc(args.threads, sizeof(data));
    pthread_t* threads = (pthread_t*)calloc(args.threads, sizeof(pthread_t));
    for(i = 0; i < args.threads; ++i){
        args.d = buffers + i;
        args.n = (i+1) * total/args.threads - i * total/args.threads;
        threads[i] = load_data_in_thread(args);
    }
    for(i = 0; i < args.threads; ++i){
        pthread_join(threads[i], 0);
    }
    *out = concat_datas(buffers, args.threads);
    out->shallow = 0;
    for(i = 0; i < args.threads; ++i){
        buffers[i].shallow = 1;
        free_data(buffers[i]);
    }
    free(buffers);
    free(threads);
    return 0;
}

load_threads()的本質是多次調用load_data_in_thread(),而load_data_in_thread()內部又是通過load_thread()加載數據。所以只要在load_thread()中添加分支,就可以讓darknet在訓練和預測時都支持直接讀取二進制數據。

args.threads:

darknet中train_classifier()函數直接設置args.threads=32

 args.threads = 32;

 

發佈了373 篇原創文章 · 獲贊 151 · 訪問量 33萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章