下載fer2013之後,解壓出的是csv格式的數據,我們需要先將數據轉換成圖片。
step 1: 從fer2013.csv中提取出訓練集、驗證集和測試集
convert_fer2013.py:
# -*- coding: utf-8 -*-
import csv
import os
database_path = r'F:\Datasets\fer2013'
datasets_path = r'.\datasets'
csv_file = os.path.join(database_path, 'fer2013.csv')
train_csv = os.path.join(datasets_path, 'train.csv')
val_csv = os.path.join(datasets_path, 'val.csv')
test_csv = os.path.join(datasets_path, 'test.csv')
with open(csv_file) as f:
csvr = csv.reader(f)
header = next(csvr)
rows = [row for row in csvr]
trn = [row[:-1] for row in rows if row[-1] == 'Training']
csv.writer(open(train_csv, 'w+'), lineterminator='\n').writerows([header[:-1]] + trn)
print(len(trn))
val = [row[:-1] for row in rows if row[-1] == 'PublicTest']
csv.writer(open(val_csv, 'w+'), lineterminator='\n').writerows([header[:-1]] + val)
print(len(val))
tst = [row[:-1] for row in rows if row[-1] == 'PrivateTest']
csv.writer(open(test_csv, 'w+'), lineterminator='\n').writerows([header[:-1]] + tst)
print(len(tst))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
注意:在Windows平臺中,需要在csv.writer()
中加上lineterminator='\n'
不然在生存的csv文件中,每行之間會有空行,影響後續操作。在Linux平臺中不需要這樣做。
step 2: 將csv中的數據轉化成圖片
convert_csv2gray:
# -*- coding: utf-8 -*-
import csv
import os
from PIL import Image
import numpy as np
datasets_path = r'.\datasets'
train_csv = os.path.join(datasets_path, 'train.csv')
val_csv = os.path.join(datasets_path, 'val.csv')
test_csv = os.path.join(datasets_path, 'test.csv')
train_set = os.path.join(datasets_path, 'train')
val_set = os.path.join(datasets_path, 'val')
test_set = os.path.join(datasets_path, 'test')
for save_path, csv_file in [(train_set, train_csv), (val_set, val_csv), (test_set, test_csv)]:
if not os.path.exists(save_path):
os.makedirs(save_path)
num = 1
with open(csv_file) as f:
csvr = csv.reader(f)
header = next(csvr)
for i, (label, pixel) in enumerate(csvr):
pixel = np.asarray([float(p) for p in pixel.split()]).reshape(48, 48)
subfolder = os.path.join(save_path, label)
if not os.path.exists(subfolder):
os.makedirs(subfolder)
im = Image.fromarray(pixel).convert('L')
image_name = os.path.join(subfolder, '{:05d}.jpg'.format(i))
print(image_name)
im.save(image_name)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
生成的數據集目錄結構如下:
第一篇文獻中有網絡結構圖,但根據我做實驗的情況來看,這篇論文水分較大,達不到論文中所說的分類精度。第二篇內容比第一篇詳細很多,很值得參考。
訓練結果: