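Data Loading
The previous posts in this series covered model definition, loss functions, and optimizers for PyTorch model training. This post looks at another essential module in training: data loading.
Deep learning training tasks vary widely. Different tasks use different data formats, and even the same task can come with datasets in different formats. So there is no universal data script; each task and each dataset format has to be handled on its own terms.
That said, deep learning frameworks generally provide a unified interface for data loading. To train for a particular task, we only need to implement our requirements according to that interface's rules.
This post first walks through PyTorch's data loading logic and then uses the CPN source code as a worked example.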
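0 Blog index
Pytorch Model Training (0) - CPN Source Code Analysis
Pytorch Model Training (1) - Model Definition
Pytorch Model Training (2) - Model Initialization
Pytorch Model Training (3) - Saving and Loading Models
Pytorch Model Training (4) - Loss Function
Pytorch Model Training (5) - Optimizer
Pytorch Model Training (6) - Data Loading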
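1 Data loading base class - data.Dataset
PyTorch's data.Dataset class is the data loading interface offered to developers. It is an abstract class, which means that to load data for a given task we subclass it and override its methods to fit our needs. Here is the source of data.Dataset: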
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    # Core method: loading, preprocessing, and augmenting a sample all happen here
    def __getitem__(self, index):
        raise NotImplementedError

    # Returns the number of samples in the dataset
    def __len__(self):
        raise NotImplementedError

    # Lets two datasets be joined with `+` into a ConcatDataset; rarely needed
    def __add__(self, other):
        return ConcatDataset([self, other])
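To make the contract concrete, here is a minimal sketch of a subclass. `SquaresDataset` and its contents are made up for illustration and are not part of PyTorch or CPN. The import falls back to `object` so the sketch also runs where PyTorch is not installed.

```python
try:
    from torch.utils.data import Dataset  # the real base class, if available
except ImportError:
    Dataset = object  # fallback so the sketch still runs without PyTorch


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        # A real dataset would load, preprocess, and augment one sample here.
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        return self.size


dataset = SquaresDataset(5)
print(len(dataset), dataset[3])  # prints: 5 (3, 9)
```

Only `__getitem__` and `__len__` have to be written; everything else about batching and shuffling is left to the loader.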
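2 Data iteration base class - data.DataLoader
The data.Dataset class above is the interface developers implement for a specific kind of data. Its job is to receive a load request, load the data, process it, and return it.
Training a model also needs a data iterator: something that directs the data.Dataset class and manages data loading in one place. In PyTorch this is the data.DataLoader class. In practice it takes a data.Dataset as an argument and produces a data loading iterator that feeds the training loop.
The data.DataLoader source is not pasted here; it is worth reading if you are interested. In brief, the data.DataLoader implementation factors out a lower-level class, _DataLoaderIter(object), where the real work happens (this is the class name in the PyTorch version this series is based on; later releases reorganize these internals). For example:
- __next__(self)
- __iter__(self)
- _process_next_batch(self, batch)
- _get_batch(self), and others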
3 Data loading flow
Assume we have already implemented our own data loading class, MyDataset(data.Dataset). The flow is then as follows:
3.1 Initialization
# Instantiate a data iterator, train_loader
train_loader = torch.utils.data.DataLoader(
    MyDataset(data_params),  # dataset-specific parameters (data path, etc.) used to initialize MyDataset
    batch_size=32, shuffle=True,  # other arguments
    num_workers=args.workers, pin_memory=True)
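Initialization does more than set parameters. MyDataset also reads the given label file or data list file and stores an index entry for each sample pair (data and label) in a list, so the samples are easy to look up later.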
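A minimal sketch of that idea follows: the constructor only indexes the samples, and the actual loading is deferred to `__getitem__`. The class name `ListFileDataset`, the JSON layout, and the file names are assumptions made for this example.

```python
import json
import os
import tempfile


class ListFileDataset(object):
    """Hypothetical dataset whose __init__ only builds the sample index."""

    def __init__(self, label_path):
        with open(label_path) as f:
            # One dict per sample, e.g. {"image": "a.jpg", "label": 0}
            self.samples = json.load(f)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        entry = self.samples[index]
        # A real implementation would open entry["image"] here.
        return entry["image"], entry["label"]


# Build a throwaway label file to exercise the class.
records = [{"image": "a.jpg", "label": 0}, {"image": "b.jpg", "label": 1}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "labels.json")
    with open(path, "w") as f:
        json.dump(records, f)
    demo = ListFileDataset(path)

print(len(demo), demo[1])  # prints: 2 ('b.jpg', 1)
```

Keeping only the index in memory means construction stays cheap even for large datasets; the expensive image reads happen one sample at a time.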
3.2 Iteration
for i, (inputs, targets) in enumerate(train_loader):
    ...
This step loads data batch by batch for training. What happens inside?
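- Call DataLoader.__iter__(self)
- Call _DataLoaderIter.__next__(self), which in turn calls _get_batch(), _process_next_batch(batch), and other functions; the first batch, however, skips this next step
- Call collate_fn; this happens inside _worker_loop:
  samples = collate_fn([dataset[i] for i in batch_indices])
- Each dataset[i] in that line resolves to our actual loading function, MyDataset.__getitem__(self, index); collate_fn then merges the returned samples into one batch
- Call __getitem__(self, index)
  This is the function we mainly need to override. Loading the data and labels, preprocessing, data augmentation, conversion to torch format, and returning the data all happen here.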
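The steps above can be approximated by a small single-process stand-in. This is a simplified sketch, not PyTorch's implementation: `simple_loader` and `simple_collate` are hypothetical names, there are no worker processes, and the collate step just groups fields into lists instead of stacking tensors.

```python
import random


def simple_collate(samples):
    """Turn [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...])."""
    return tuple(list(field) for field in zip(*samples))


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Yield batches from any object that supports len() and integer indexing."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        # Same shape as the line quoted from _worker_loop:
        # dataset[i] triggers __getitem__, then the collate step merges the samples.
        yield simple_collate([dataset[i] for i in batch_indices])


toy = [(i, i % 2) for i in range(5)]  # a plain list already supports [] and len()
batches = list(simple_loader(toy, batch_size=2))
print(batches)  # prints: [([0, 1], [0, 1]), ([2, 3], [0, 1]), ([4], [0])]
```

The last batch is smaller than batch_size because five samples do not divide evenly by two.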
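4 CPN - data loading
4.1 The MscocoMulti class
This source code trains human keypoint detection on COCO data.
4.1.1 COCO loading flow
1) Definition
class MscocoMulti(data.Dataset):  # inherits from data.Dataset
2) Initialization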
def __init__(self, cfg, train=True):  # cfg: configuration object
    self.img_folder = cfg.img_path  # image directory
    self.is_train = train  # training flag
    self.inp_res = cfg.data_shape  # model input size, used to process images and labels
    self.out_res = cfg.output_shape  # model output size, used to process labels
    self.pixel_means = cfg.pixel_means  # mean values, used in image preprocessing
    self.num_class = cfg.num_class  # number of outputs (keypoints)
    self.cfg = cfg  # remaining parameters
    self.bbox_extend_factor = cfg.bbox_extend_factor  # expansion ratio for the person box
    if train:
        self.scale_factor = cfg.scale_factor  # range of image scaling ratios
        self.rot_factor = cfg.rot_factor  # maximum rotation angle
        self.symmetry = cfg.symmetry  # symmetric keypoint pairs, used when flipping
    with open(cfg.gt_path) as anno_file:
        self.anno = json.load(anno_file)  # load the label file into a list named anno
3) __len__(self)
def __len__(self):
    return len(self.anno)  # number of samples
4) __getitem__(self, index)
def __getitem__(self, index):
    # 1 Fetch the sample record
    a = self.anno[index]  # one sample: a dict with the basic information for this sample
    image_name = a['imgInfo']['img_paths']  # image file name for this sample
    img_path = os.path.join(self.img_folder, image_name)  # full image path
    if self.is_train:
        # Reshape the keypoint labels from a flat list of 51 values to 17x3
        # Each row is [x, y, valid], where valid is the visibility class of the point
        # COCO visible: 0-no label, 1-label + invisible, 2-label + visible
        points = np.array(a['unit']['keypoints']).reshape(self.num_class, 3).astype(np.float32)
    gt_bbox = a['unit']['GT_bbox']  # bounding box of the person
    # 2 Read the image (scipy.misc.imread was removed in SciPy 1.2, so this call needs an older SciPy)
    image = scipy.misc.imread(img_path, mode='RGB')
    # 3 Crop the image; in training, apply the same transform to points
    if self.is_train:
        image, points, details = self.augmentationCropImage(image, gt_bbox, points)
    else:
        image, details = self.augmentationCropImage(image, gt_bbox)
    # 4 Data augmentation (training only)
    if self.is_train:
        image, points = self.data_augmentation(image, points, a['operation'])  # data augmentation
        img = im_to_torch(image)  # convert to torch layout: HxWxC ==> CxHxW
        # Color dithering: scale each channel by a random factor
        img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
        img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
        img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
        # Rescale the labels: CPN downsamples by a factor of 4, so its output is 1/4 of the input
        points[:, :2] //= 4  # output size is 1/4 input size
        pts = torch.Tensor(points)  # hand over to a torch tensor
    else:
        img = im_to_torch(image)
    img = color_normalize(img, self.pixel_means)  # subtract the mean
    # 5 In training, build target heatmaps from points (the network outputs heatmaps)
    if self.is_train:
        target15 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
        target11 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
        target9 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
        target7 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
        for i in range(self.num_class):
            if pts[i, 2] > 0:  # COCO visible: 0-no label, 1-label + invisible, 2-label + visible
                target15[i] = generate_heatmap(target15[i], pts[i], self.cfg.gk15)
                target11[i] = generate_heatmap(target11[i], pts[i], self.cfg.gk11)
                target9[i] = generate_heatmap(target9[i], pts[i], self.cfg.gk9)
                target7[i] = generate_heatmap(target7[i], pts[i], self.cfg.gk7)
        targets = [torch.Tensor(target15), torch.Tensor(target11), torch.Tensor(target9), torch.Tensor(target7)]
        valid = pts[:, 2]  # visibility class of each point
    # 6 Sample metadata
    meta = {'index' : index, 'imgID' : a['imgInfo']['imgID'],
            'GT_bbox' : np.array([gt_bbox[0], gt_bbox[1], gt_bbox[2], gt_bbox[3]]),
            'img_path' : img_path, 'augmentation_details' : details}
    # 7 Return the data
    if self.is_train:
        return img, targets, valid, meta
    else:
        meta['det_scores'] = a['score']
        return img, meta
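The post does not show the body of generate_heatmap, so the following is only a dependency-free sketch of the general idea behind steps 1, 4, and 5: take the flat keypoint list, scale coordinates down by the output stride, and draw one Gaussian map per labeled keypoint. The function names, the `sigma` value, and the choice to evaluate the Gaussian directly are all assumptions, not the CPN implementation.

```python
import math


def gaussian_heatmap(height, width, cx, cy, sigma):
    """Return a height x width grid (list of rows) that peaks at (cx, cy) with value 1.0."""
    return [[math.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2.0 * sigma ** 2))
             for x in range(width)]
            for y in range(height)]


def keypoints_to_targets(flat_keypoints, out_h, out_w, stride=4, sigma=2.0):
    """flat_keypoints is [x0, y0, v0, x1, y1, v1, ...] in input-image pixels."""
    targets = []
    for k in range(0, len(flat_keypoints), 3):
        x, y, v = flat_keypoints[k:k + 3]
        if v > 0:
            # Labeled keypoint (COCO visibility 1 or 2): draw a peak at the scaled position.
            targets.append(gaussian_heatmap(out_h, out_w, x // stride, y // stride, sigma))
        else:
            # Unlabeled keypoint: keep an all-zero map.
            targets.append([[0.0] * out_w for _ in range(out_h)])
    return targets


# Two keypoints: one visible at (8, 12), one unlabeled.
maps = keypoints_to_targets([8, 12, 2, 0, 0, 0], out_h=8, out_w=6)
print(maps[0][3][2])  # peak at row 12 // 4 = 3, column 8 // 4 = 2; prints: 1.0
```

The four targets in the code above (target15 through target7) differ only in the Gaussian kernel configuration passed in from cfg, which in this sketch would correspond to calling the function with different `sigma` values.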
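4.1.2 augmentationCropImage
This function uses the bbox to cut the person region out of the original image and applies the matching transform to the keypoints. It covers border padding, expanding and scaling the person bounding box, cropping the image, resizing the crop, and recording the crop position.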
def augmentationCropImage(self, img, bbox, joints=None):
    height, width = self.inp_res[0], self.inp_res[1]
    bbox = np.array(bbox).reshape(4, ).astype(np.float32)
    # Pad every side with the mean color so the crop never leaves the image
    add = max(img.shape[0], img.shape[1])
    mean_value = self.pixel_means
    bimg = cv2.copyMakeBorder(img, add, add, add, add, borderType=cv2.BORDER_CONSTANT, value=mean_value.tolist())
    objcenter = np.array([(bbox[0] + bbox[2]) / 2., (bbox[1] + bbox[3]) / 2.])
    # Shift box, center, and keypoints into padded-image coordinates
    bbox += add
    objcenter += add
    if self.is_train:
        joints[:, :2] += add
        inds = np.where(joints[:, -1] == 0)
        joints[inds, :2] = -1000000  # avoid influencing by data processing
    # Expand the person box
    crop_width = (bbox[2] - bbox[0]) * (1 + self.bbox_extend_factor[0] * 2)
    crop_height = (bbox[3] - bbox[1]) * (1 + self.bbox_extend_factor[1] * 2)
    if self.is_train:
        crop_width = crop_width * (1 + 0.25)
        crop_height = crop_height * (1 + 0.25)
    # Pick the limiting side so the crop keeps the input aspect ratio
    if crop_height / height > crop_width / width:
        crop_size = crop_height
        min_shape = height
    else:
        crop_size = crop_width
        min_shape = width
    # Clamp the crop so it stays inside the padded image
    crop_size = min(crop_size, objcenter[0] / width * min_shape * 2. - 1.)
    crop_size = min(crop_size, (bimg.shape[1] - objcenter[0]) / width * min_shape * 2. - 1)
    crop_size = min(crop_size, objcenter[1] / height * min_shape * 2. - 1.)
    crop_size = min(crop_size, (bimg.shape[0] - objcenter[1]) / height * min_shape * 2. - 1)
    min_x = int(objcenter[0] - crop_size / 2. / min_shape * width)
    max_x = int(objcenter[0] + crop_size / 2. / min_shape * width)
    min_y = int(objcenter[1] - crop_size / 2. / min_shape * height)
    max_y = int(objcenter[1] + crop_size / 2. / min_shape * height)
    x_ratio = float(width) / (max_x - min_x)
    y_ratio = float(height) / (max_y - min_y)
    if self.is_train:
        # Move keypoints into crop coordinates, then scale to the input size
        joints[:, 0] = joints[:, 0] - min_x
        joints[:, 1] = joints[:, 1] - min_y
        joints[:, 0] *= x_ratio
        joints[:, 1] *= y_ratio
        label = joints[:, :2].copy()
        valid = joints[:, 2].copy()
    img = cv2.resize(bimg[min_y:max_y, min_x:max_x, :], (width, height))
    # Crop box in original-image coordinates (np.float was removed in NumPy 1.24; float is equivalent)
    details = np.asarray([min_x - add, min_y - add, max_x - add, max_y - add]).astype(float)
    if self.is_train:
        return img, joints, details
    else:
        return img, details
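The keypoint arithmetic near the end of this function is a translate-then-scale mapping. The sketch below isolates it as a pair of hypothetical helpers. The inverse mapping is an addition that does not appear in the post; it is included only to show that the recorded crop box is enough to undo the transform. The crop box and output size are made-up values.

```python
def to_crop_coords(x, y, crop_box, out_size):
    """Map a point from the padded image into the resized crop."""
    min_x, min_y, max_x, max_y = crop_box
    width, height = out_size
    x_ratio = float(width) / (max_x - min_x)
    y_ratio = float(height) / (max_y - min_y)
    return (x - min_x) * x_ratio, (y - min_y) * y_ratio


def from_crop_coords(u, v, crop_box, out_size):
    """Inverse mapping: from the resized crop back to the padded image."""
    min_x, min_y, max_x, max_y = crop_box
    width, height = out_size
    x_ratio = float(width) / (max_x - min_x)
    y_ratio = float(height) / (max_y - min_y)
    return u / x_ratio + min_x, v / y_ratio + min_y


crop_box = (100, 50, 484, 562)  # hypothetical crop: 384 wide, 512 tall
out_size = (192, 256)           # (width, height), an assumed network input size
u, v = to_crop_coords(292, 306, crop_box, out_size)
x, y = from_crop_coords(u, v, crop_box, out_size)
print((u, v), (x, y))  # prints: (96.0, 128.0) (292.0, 306.0)
```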
4.1.3 data_augmentation
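This function applies data augmentation to training samples: random scaling, followed by a horizontal flip or a random rotation depending on the sample's operation value.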
def data_augmentation(self, img, label, operation):
    height, width = img.shape[0], img.shape[1]
    center = (width / 2., height / 2.)
    n = label.shape[0]
    # random scaling: crop a centered window and resize it back to the full size
    affrat = random.uniform(self.scale_factor[0], self.scale_factor[1])
    halfl_w = min(width - center[0], (width - center[0]) / 1.25 * affrat)
    halfl_h = min(height - center[1], (height - center[1]) / 1.25 * affrat)
    img = skimage.transform.resize(img[int(center[1] - halfl_h): int(center[1] + halfl_h + 1),
                                   int(center[0] - halfl_w): int(center[0] + halfl_w + 1)], (height, width))
    for i in range(n):
        label[i][0] = (label[i][0] - center[0]) / halfl_w * (width - center[0]) + center[0]
        label[i][1] = (label[i][1] - center[1]) / halfl_h * (height - center[1]) + center[1]
        # zero the visibility flag of points that fall outside the image
        label[i][2] *= (
            (label[i][0] >= 0) & (label[i][0] < width) & (label[i][1] >= 0) & (label[i][1] < height))
    # flip augmentation
    if operation == 1:
        img = cv2.flip(img, 1)
        cod = []
        allc = []
        for i in range(n):
            x, y = label[i][0], label[i][1]
            if x >= 0:
                x = width - 1 - x
            cod.append((x, y, label[i][2]))
        # **** the joint index depends on the dataset ****
        for (q, w) in self.symmetry:
            cod[q], cod[w] = cod[w], cod[q]
        for i in range(n):
            allc.append(cod[i][0])
            allc.append(cod[i][1])
            allc.append(cod[i][2])
        label = np.array(allc).reshape(n, 3)
    # rotated augmentation
    if operation > 1:
        angle = random.uniform(0, self.rot_factor)
        if random.randint(0, 1):
            angle *= -1
        rotMat = cv2.getRotationMatrix2D(center, angle, 1.0)
        img = cv2.warpAffine(img, rotMat, (width, height))
        allc = []
        for i in range(n):
            x, y = label[i][0], label[i][1]
            v = label[i][2]
            coor = np.array([x, y])
            if x >= 0 and y >= 0:
                R = rotMat[:, : 2]
                W = np.array([rotMat[0][2], rotMat[1][2]])
                coor = np.dot(R, coor) + W
            allc.append(int(coor[0]))
            allc.append(int(coor[1]))
            v *= ((coor[0] >= 0) & (coor[0] < width) & (coor[1] >= 0) & (coor[1] < height))
            allc.append(int(v))
        # np.int was removed in NumPy 1.24; the builtin int is equivalent
        label = np.array(allc).reshape(n, 3).astype(int)
    return img, label
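The label side of the flip and rotation branches can be looked at in isolation. The sketch below is a simplified, dependency-free restatement with hypothetical function names and made-up sample points. The rotation uses the same matrix layout that cv2.getRotationMatrix2D documents for scale 1, but unlike the code above it keeps float coordinates and rotates every point, including ones with negative coordinates.

```python
import math


def flip_keypoints(points, width, symmetry):
    """Mirror (x, y, v) keypoints horizontally, then swap left/right pairs."""
    flipped = [((width - 1 - x) if x >= 0 else x, y, v) for x, y, v in points]
    for left, right in symmetry:
        # After mirroring, the left joint sits where the right joint used to be.
        flipped[left], flipped[right] = flipped[right], flipped[left]
    return flipped


def rotate_keypoints(points, center, angle_deg, width, height):
    """Rotate keypoints about center; clear the flag of points leaving the image."""
    cx, cy = center
    a = math.cos(math.radians(angle_deg))
    b = math.sin(math.radians(angle_deg))
    rotated = []
    for x, y, v in points:
        # Affine matrix [[a, b, (1 - a) * cx - b * cy], [-b, a, b * cx + (1 - a) * cy]]
        rx = a * x + b * y + (1 - a) * cx - b * cy
        ry = -b * x + a * y + b * cx + (1 - a) * cy
        inside = 0 <= rx < width and 0 <= ry < height
        rotated.append((rx, ry, v if inside else 0))
    return rotated


sample = [(10, 20, 2), (30, 20, 2), (50, 5, 1)]  # indices 0 and 1 form a left/right pair
flipped = flip_keypoints(sample, width=64, symmetry=[(0, 1)])
print(flipped)  # prints: [(33, 20, 2), (53, 20, 2), (13, 5, 1)]

rotated = rotate_keypoints([(40, 32, 2)], center=(32, 32), angle_deg=90, width=64, height=64)
```

With a 90 degree angle, the point 8 pixels to the right of the center ends up about 8 pixels above it, which matches the counter-clockwise convention for positive angles in image coordinates.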
4.2 Loading in the training loop
1) Instantiate the data loader
train_loader = torch.utils.data.DataLoader(
    MscocoMulti(cfg),
    batch_size=cfg.batch_size*args.num_gpus, shuffle=True,
    num_workers=args.workers, pin_memory=True)
2) Epoch loop
for epoch in range(args.start_epoch, args.epochs):
    lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma)
    print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
    # train for one epoch
    train_loss = train(train_loader, model, [criterion1, criterion2], optimizer)
    print('train_loss: ', train_loss)
3) Batch loop
# Iterating over train_loader calls the __getitem__(self, index) shown above to load the data
for i, (inputs, targets, valid, meta) in enumerate(train_loader):
    input_var = torch.autograd.Variable(inputs.cuda())
    target15, target11, target9, target7 = targets
    # non_blocking replaces the old async argument (async is a reserved word since Python 3.7)
    refine_target_var = torch.autograd.Variable(target7.cuda(non_blocking=True))
    valid_var = torch.autograd.Variable(valid.cuda(non_blocking=True))
    # compute output
    global_outputs, refine_output = model(input_var)
    score_map = refine_output.data.cpu()