原文地址:http://blog.csdn.net/hjimce/article/details/49248231
作者:hjimce
caffe對於訓練數據格式,支持:lmdb、h5py……,其中lmdb數據格式常用於單標籤數據,像分類等,經常使用lmdb的數據格式。對於迴歸等問題,或者多標籤數據,一般使用h5py數據的格式。當然好像還有其它格式的數據可用,不過我一般使用這兩種數據格式,因此本文就主要針對這兩種數據格式的製作方法,進行簡單講解。
一、lmdb數據
lmdb用於單標籤數據。爲了簡單起見,我後面通過一個性別分類作爲例子,進行相關數據製作講解。
1、數據準備
首先我們要準備好訓練數據,然後新建一個名爲train的文件夾和一個val的文件夾:
train文件存放訓練數據,val文件存放驗證數據。然後我們在train文件下面,把訓練數據性別爲男、女圖片各放在一個文件夾下面:
同樣的我們在val文件下面也創建文件夾:
兩個文件也是分別存我們用於驗證的圖片數據男女性別文件。我們在test_female下面存放了都是女性的圖片,然後在test_male下面存放的都是驗證數據的男性圖片。
2、標籤文件.txt文件製作.
接着我們需要製作一個train.txt、val.txt文件,這兩個文件分別包含了我們上面的訓練數據的圖片路徑,以及其對應的標籤,如下所示。
我們把女生圖片標號爲1,男生圖片標記爲0。標籤數據文件txt的生成可以通過如下代碼,通過掃描路徑男、女性別下面的圖片,得到標籤文件train.txt和val.txt:
-
<span style="font-family:Arial;font-size:18px;"><span style="font-size:18px;"><span style="font-size:18px;">import os
-
import numpy as np
-
from matplotlib import pyplot as plt
-
import cv2
-
import shutil
-
-
-
-
def GetFileList(FindPath,FlagStr=[]):
-
import os
-
FileList=[]
-
FileNames=os.listdir(FindPath)
-
if len(FileNames)>0:
-
for fn in FileNames:
-
if len(FlagStr)>0:
-
if IsSubString(FlagStr,fn):
-
fullfilename=os.path.join(FindPath,fn)
-
FileList.append(fullfilename)
-
else:
-
fullfilename=os.path.join(FindPath,fn)
-
FileList.append(fullfilename)
-
-
-
if len(FileList)>0:
-
FileList.sort()
-
-
return FileList
-
def IsSubString(SubStrList,Str):
-
flag=True
-
for substr in SubStrList:
-
if not(substr in Str):
-
flag=False
-
-
return flag
-
-
txt=open('train.txt','w')
-
-
imgfile=GetFileList('first_batch/train_female')
-
for img in imgfile:
-
str=img+'\t'+'1'+'\n'
-
txt.writelines(str)
-
-
imgfile=GetFileList('first_batch/train_male')
-
for img in imgfile:
-
str=img+'\t'+'0'+'\n'
-
txt.writelines(str)
-
txt.close()</span></span></span>
把生成的標籤文件,和train\val文件夾放在同一個目錄下面:
需要注意,我們標籤數據文件裏的文件路徑和圖片的路徑要對應的起來,比如val.txt文件的某一行的圖片路徑,是否在val文件夾下面:
3、生成lmdb數據
接着我們的目的就是要通過上面的四個文件,把圖片的數據和其對應的標籤打包起來,打包成lmdb數據格式,打包腳本如下:
-
<span style="font-family:Arial;font-size:18px;"><span style="font-size:18px;">
-
-
-
-
EXAMPLE=.
-
TOOLS=//../build/tools
-
DATA=.
-
-
TRAIN_DATA_ROOT=train/
-
VAL_DATA_ROOT=val/
-
-
-
-
-
-
RESIZE=true
-
if $RESIZE; then
-
RESIZE_HEIGHT=256
-
RESIZE_WIDTH=256
-
else
-
RESIZE_HEIGHT=0
-
RESIZE_WIDTH=0
-
fi
-
-
if [ ! -d "$TRAIN_DATA_ROOT" ]; then
-
echo "Error: TRAIN_DATA_ROOT is not a path to a directory: $TRAIN_DATA_ROOT"
-
echo "Set the TRAIN_DATA_ROOT variable in create_imagenet.sh to the path" \
-
"where the ImageNet training data is stored."
-
exit 1
-
fi
-
-
if [ ! -d "$VAL_DATA_ROOT" ]; then
-
echo "Error: VAL_DATA_ROOT is not a path to a directory: $VAL_DATA_ROOT"
-
echo "Set the VAL_DATA_ROOT variable in create_imagenet.sh to the path" \
-
"where the ImageNet validation data is stored."
-
exit 1
-
fi
-
-
echo "Creating train lmdb..."
-
-
GLOG_logtostderr=1 $TOOLS/convert_imageset \
-
--resize_height=$RESIZE_HEIGHT \
-
--resize_width=$RESIZE_WIDTH \
-
--shuffle \
-
$TRAIN_DATA_ROOT \
-
$DATA/train.txt \
-
$EXAMPLE/train_lmdb
-
-
echo "Creating val lmdb..."
-
-
GLOG_logtostderr=1 $TOOLS/convert_imageset \
-
--resize_height=$RESIZE_HEIGHT \
-
--resize_width=$RESIZE_WIDTH \
-
--shuffle \
-
$VAL_DATA_ROOT \
-
$DATA/val.txt \
-
$EXAMPLE/val_lmdb
-
-
echo "Done."</span></span>
通過運行上面的腳本,我們即將得到文件夾train_lmdb\val_lmdb:
我們打開train_lmdb文件夾
並查看一下文件data.mdb數據的大小,如果這個數據包好了我們所有的訓練圖片數據,查一下這個文件的大小是否符合預期大小,如果文件的大小才幾k而已,那麼就代表你沒有打包成功,估計是因爲路徑設置錯誤。我們也可以通過如下的代碼讀取上面打包好的數據,把圖片、和標籤打印出來,查看一下,查看lmdb數據請參考下面的代碼:
Python lmdb數據驗證:
-
<span style="font-family:Arial;font-size:18px;"><span style="font-size:18px;">
-
caffe_root = '/home/hjimce/caffe/'
-
import sys
-
sys.path.insert(0, caffe_root + 'python')
-
import caffe
-
-
import os
-
import lmdb
-
import numpy
-
import matplotlib.pyplot as plt
-
-
-
def readlmdb(path,visualize = False):
-
env = lmdb.open(path, readonly=True,lock=False)
-
-
datum = caffe.proto.caffe_pb2.Datum()
-
x=[]
-
y=[]
-
with env.begin() as txn:
-
cur = txn.cursor()
-
for key, value in cur:
-
-
datum.ParseFromString(value)
-
-
img_data = numpy.array(bytearray(datum.data))\
-
.reshape(datum.channels, datum.height, datum.width)
-
print img_data.shape
-
x.append(img_data)
-
y.append(datum.label)
-
if visualize:
-
img_data=img_data.transpose([1,2,0])
-
img_data = img_data[:, :, ::-1]
-
plt.imshow(img_data)
-
plt.show()
-
print datum.label
-
return x,y</span></span>
通過上面的函數,我們可以是讀取相關的lmdb數據文件。
4、製作均值文件。
這個是爲了圖片歸一化而生成的圖片平均值文件,把所有的圖片相加起來,做平均,具體的腳本如下:
-
-
-
-
-
EXAMPLE=.
-
DATA=train
-
TOOLS=../../build/tools
-
-
$TOOLS/compute_image_mean $EXAMPLE/train_lmdb \ #train_lmdb是我們上面打包好的lmdb數據文件
-
$DATA/imagenet_mean.binaryproto
-
-
echo "Done."
運行這個腳本,我們就可以訓練圖片均值文件:imagenet_mean.binaryproto
至此,我們得到了三個文件:imagenet_mean.binaryproto、train_lmdb、val_lmdb,這三個文件就是我們最後打包好的數據,這些數據我們即將作爲caffe的數據輸入數據格式文件,把這三個文件拷貝出來,就可以把原來還沒有打包好的數據刪了。這三個文件,我們在caffe的網絡結構文件,數據層定義輸入數據的時候,就會用到了:
-
name: "CaffeNet"
-
layers {
-
name: "data"
-
type: DATA
-
top: "data"
-
top: "label"
-
data_param {
-
source: "train_lmdb"#lmbd格式的訓練數據
-
backend: LMDB
-
batch_size: 50
-
}
-
transform_param {
-
crop_size: 227
-
mirror: true
-
mean_file:"imagenet_mean.binaryproto"#均值文件
-
-
}
-
include: { phase: TRAIN }
-
}
-
layers {
-
name: "data"
-
type: DATA
-
top: "data"
-
top: "label"
-
data_param {
-
source: "val_lmdb"#lmdb格式的驗證數據
-
backend: LMDB
-
batch_size: 50
-
}
-
transform_param {
-
crop_size: 227
-
mirror: false
-
mean_file:"imagenet_mean.binaryproto"#均值文件
-
}
-
include: { phase: TEST }
-
}
二、h5py格式數據
上面的lmdb一般用於單標籤數據,圖片分類的時候,大部分用lmdb格式。然而假設我們要搞的項目是人臉特徵點識別,我們要識別出68個人臉特徵點,也就是相當於136維的輸出向量。網上查了一下,對於caffe多標籤輸出,需要使用h5py格式的數據,而且使用h5py的數據格式的時候,caffe是不能使用數據擴充進行相關的數據變換的,很是悲劇啊,所以如果caffe使用h5py數據格式的話,需要自己在外部,進行數據擴充,數據歸一化等相關的數據預處理操作。
1、h5py數據格式生成
下面演示一下數據h5py數據格式的製作:
-
-
caffe_root = '/home/hjimce/caffe/'
-
import sys
-
sys.path.insert(0, caffe_root + 'python')
-
import os
-
import cv2
-
import numpy as np
-
import h5py
-
from common import shuffle_in_unison_scary, processImage
-
import matplotlib.pyplot as plt
-
-
def readdata(filepath):
-
fr=open(filepath,'r')
-
filesplit=[]
-
for line in fr.readlines():
-
s=line.split()
-
s[1:]=[float(x) for x in s[1:]]
-
filesplit.append(s)
-
fr.close()
-
return filesplit
-
-
def sqrtimg(img):
-
height,width=img.shape[:2]
-
maxlenght=max(height,width)
-
sqrtimg0=np.zeros((maxlenght,maxlenght,3),dtype='uint8')
-
-
sqrtimg0[(maxlenght*.5-height*.5):(maxlenght*.5+height*.5),(maxlenght*.5-width*.5):(maxlenght*.5+width*.5)]=img
-
return sqrtimg0
-
-
-
def generate_hdf5():
-
-
labelfile =readdata('../data/my_alige_landmark.txt')
-
F_imgs = []
-
F_landmarks = []
-
-
-
for i,l in enumerate(labelfile):
-
imgpath='../data/'+l[0]
-
-
img=cv2.imread(imgpath)
-
maxx=max(img.shape[0],img.shape[1])
-
img=sqrtimg(img)
-
img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
-
f_face=cv2.resize(img,(39,39))
-
-
plt.imshow(f_face,cmap='gray')
-
-
-
f_face = f_face.reshape((1, 39, 39))
-
f_landmark =np.asarray(l[1:],dtype='float')
-
-
F_imgs.append(f_face)
-
-
-
-
f_landmark=f_landmark/maxx
-
print f_landmark
-
F_landmarks.append(f_landmark)
-
-
-
F_imgs, F_landmarks = np.asarray(F_imgs), np.asarray(F_landmarks)
-
-
-
F_imgs = processImage(F_imgs)
-
shuffle_in_unison_scary(F_imgs, F_landmarks)
-
-
-
with h5py.File(os.getcwd()+ '/train_data.h5', 'w') as f:
-
f['data'] = F_imgs.astype(np.float32)
-
f['landmark'] = F_landmarks.astype(np.float32)
-
-
with open(os.getcwd() + '/train.txt', 'w') as f:
-
f.write(os.getcwd() + '/train_data.h5\n')
-
print i
-
-
-
if __name__ == '__main__':
-
generate_hdf5()
利用上面的代碼,可以生成一個train.txt、train_data.h5的文件,然後在caffe的prototxt中,進行訓練的時候,可以用如下的代碼,作爲數據層的調用:
-
layer {
-
name: "hdf5_train_data"
-
type: "HDF5Data"
-
top: "data"
-
top: "landmark"
-
include {
-
phase: TRAIN
-
}
-
hdf5_data_param {
-
source: "h5py/train.txt"
-
batch_size: 64
-
}
-
}
上面需要注意的是,相比與lmdb的數據格式,我們需要該動的地方,我標註的地方就是需要改動的地方,還有h5py不支持數據變換。
2、h5py數據讀取
-
f=h5py.File('../h5py/train.h5','r')
-
x=f['data'][:]
-
x=np.asarray(x,dtype='float32')
-
y=f['label'][:]
-
y=np.asarray(y,dtype='float32')
-
print x.shape
-
print y.shape
可以通過上面代碼,查看我們生成的.h5格式文件。
在需要注意的是,我們輸入caffe的h5py圖片數據爲四維矩陣(number_samples,nchannels,height,width)的矩陣,標籤矩陣爲二維(number_samples,labels_ndim),同時數據的格式需要轉成float32,用於迴歸任務。
**********************作者:hjimce 時間:2015.10.2 聯繫QQ:1393852684 原創文章,轉載請保留原文地址、作者等信息***************