caffe表情識別(一):準備數據fer2013

  • 數據集: Kaggle fer2013(
  • 下載fer2013之後,解壓出的是csv格式的數據,我們需要先將數據轉換成圖片。

    這裏寫圖片描述

    這裏寫圖片描述

    step 1: 從fer2013.csv中提取出訓練集、驗證集和測試集

    這裏寫圖片描述

    convert_fer2013.py:

    # -*- coding: utf-8 -*-
    import csv
    import os
    
    database_path = r'F:\Datasets\fer2013'
    datasets_path = r'.\datasets'
    csv_file = os.path.join(database_path, 'fer2013.csv')
    train_csv = os.path.join(datasets_path, 'train.csv')
    val_csv = os.path.join(datasets_path, 'val.csv')
    test_csv = os.path.join(datasets_path, 'test.csv')
    
    
    with open(csv_file) as f:
        csvr = csv.reader(f)
        header = next(csvr)
        rows = [row for row in csvr]
    
        trn = [row[:-1] for row in rows if row[-1] == 'Training']
        csv.writer(open(train_csv, 'w+'), lineterminator='\n').writerows([header[:-1]] + trn)
        print(len(trn))
    
        val = [row[:-1] for row in rows if row[-1] == 'PublicTest']
        csv.writer(open(val_csv, 'w+'), lineterminator='\n').writerows([header[:-1]] + val)
        print(len(val))
    
        tst = [row[:-1] for row in rows if row[-1] == 'PrivateTest']
        csv.writer(open(test_csv, 'w+'), lineterminator='\n').writerows([header[:-1]] + tst)
        print(len(tst))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    注意:在Windows平臺中,需要在csv.writer()中加上lineterminator='\n'不然在生存的csv文件中,每行之間會有空行,影響後續操作。在Linux平臺中不需要這樣做。

    step 2: 將csv中的數據轉化成圖片

    這裏寫圖片描述

    convert_csv2gray:

    # -*- coding: utf-8 -*-
    import csv
    import os
    from PIL import Image
    import numpy as np
    
    
    datasets_path = r'.\datasets'
    train_csv = os.path.join(datasets_path, 'train.csv')
    val_csv = os.path.join(datasets_path, 'val.csv')
    test_csv = os.path.join(datasets_path, 'test.csv')
    
    train_set = os.path.join(datasets_path, 'train')
    val_set = os.path.join(datasets_path, 'val')
    test_set = os.path.join(datasets_path, 'test')
    
    for save_path, csv_file in [(train_set, train_csv), (val_set, val_csv), (test_set, test_csv)]:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
    
        num = 1
        with open(csv_file) as f:
            csvr = csv.reader(f)
            header = next(csvr)
            for i, (label, pixel) in enumerate(csvr):
                pixel = np.asarray([float(p) for p in pixel.split()]).reshape(48, 48)
                subfolder = os.path.join(save_path, label)
                if not os.path.exists(subfolder):
                    os.makedirs(subfolder)
                im = Image.fromarray(pixel).convert('L')
                image_name = os.path.join(subfolder, '{:05d}.jpg'.format(i))
                print(image_name)
                im.save(image_name)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    生成的數據集目錄結構如下:

    這裏寫圖片描述

  • 參考文獻:
  • 第一篇文獻中有網絡結構圖,但根據我做實驗的情況來看,這篇論文水分較大,達不到論文中所說的分類精度。第二篇內容比第一篇詳細很多,很值得參考。

    訓練結果:
    這裏寫圖片描述

    發表評論
    所有評論
    還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
    相關文章