Pytorch模型训练(6) - 数据加载

《 数据加载》

  前面几篇博客分析Pytorch模型训练的模型定义,损失函数及优化器,本文则来看看在模型训练中,又一非常重要的模块—数据加载
  在深度学习模型训练中,我们面对的训练任务是多种多样的,不同任务面对的数据格式也是不同的,甚至相同任务,也会面对不同格式的数据集;所以不存在所谓的通用数据脚本,只能是具体任务,具体数据集格式,单独对待。
  但是呢!!!深度学习框架一般都会为数据加载提供同一的接口,我们实现某个任务训练时,只需要按照其接口规则,实现我们需求即可。
  本文就先来分析下Pytorch中的数据加载逻辑,并拿CPN源码来进行实例分析。

0 博客目录

Pytorch模型训练(0) - CPN源码解析
Pytorch模型训练(1) - 模型定义
Pytorch模型训练(2) - 模型初始化
Pytorch模型训练(3) - 模型保存与加载
Pytorch模型训练(4) - Loss Function
Pytorch模型训练(5) - Optimizer
Pytorch模型训练(6) - 数据加载

1 数据加载基类–data.Dataset

  Pytorch中的data.Dataset类,就是给开发者提供的数据加载接口类,它是一个抽象类,也就是说,我们实现某个任务数据加载,就需要根据我们具体需求重写这个类,先看看data.Dataset源码

class Dataset(object):
    """An abstract class representing a Dataset.   
    数据表示的抽象类
    
    All other datasets should subclass it. All subclasses should override   
    所有数据类都需要继承它,并重写它的方法
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    
	数据加载方法,核心方法(数据加载,数据处理,数据增强等都将在这里进行)
    def __getitem__(self, index):
        raise NotImplementedError
	
	 获取当前数据集长度
    def __len__(self):
        raise NotImplementedError
   
	不太清楚这个方法,不过一般用不到
    def __add__(self, other):
        return ConcatDataset([self, other])

2 数据迭代基类–data.DataLoader

  上面的data.Dataset类,是供开发者实现特定数据加载的接口,它的功能是接到数据加载命令,加载数据,处理数据,返回数据。
  而我们训练模型时,还需要一个数据迭代器,一个指挥data.Dataset类的指挥者,一个统一管理数据加载的管理者,在Pytorch中就是data.DataLoader类。实际应用中,这个类接受data.Dataset作为一个参数,实例化一个数据加载迭代器,为训练模型服务。
  data.DataLoader源码这里就不粘贴了,有兴趣可以看看,这里简述一下;data.DataLoader类在实现时,又提炼出更底层_DataLoaderIter(object)类,实际的方法都在这个类中,比如:

  • _next_(self)函数
  • _iter_(self)函数
  • _process_next_batch(self, batch)函数
  • _get_batch(self)函数等

3 数据加载流程

  这里我们假设,我们已经实现来自己的数据加载类MyDataset(data.Dataset),则流程如下:

3.1 初始化

#实例化一个数据迭代器train_loader
 train_loader = torch.utils.data.DataLoader(
    MyDataset(data_params),     #传入一些关于数据集信息的参数(如数据路径等),初始化MyDataset
    batch_size=32, shuffle=True,  #其他参数
    num_workers=args.workers, pin_memory=True) 

  在初始化中,不光会初始化一些参数,它还会将读取给定的label文件或数据列表文件,将样本对(数据和label)的索引保存在一个list对象中,方便后面读取

3.2 迭代

for i, (inputs, targets) in enumerate(train_loader): 
	。。。。

  这一步,就是不断迭代加载数据进行训练,而其内部就是怎么样的流程呢?

  1. 调用DataLoader._iter_(self)

  2. 调用_DataLoaderIter._next_(self),其内部又会调用_get_batch(),_process_next_batch(batch)等函数,但第一个batch会跳过next这一步

  3. 调用collate_fn,这个函数在worker.pyworker.py中_worker_loop中调用

    samples = collate_fn([dataset[i] for i in batch_indices])

  4. collate_fn则会找到我们真正加载数据函数:MyDataset._getitem_(self, index)

  5. 调用_getitem_(self, index)
    这个函数才是我们需要重点重写对象,我们数据和label的加载,预处理,数据增强,torch格式转化,数据返回等等,都在这个函数完成。

4 CPN–数据加载

4.1 MscocoMulti类

  该源码是用COCO数据来训练人体关键点

4.1.1 COCO加载流程

  1)定义

class MscocoMulti(data.Dataset)  #继承自data.Dataset

  2)初始化

    def __init__(self, cfg, train=True):                    #cfg 参数对象
        self.img_folder = cfg.img_path                      #图像路径
        self.is_train = train                               #训练标志
        self.inp_res = cfg.data_shape                       #模型输入尺寸  处理图像和lebel用
        self.out_res = cfg.output_shape                     #模型输出尺寸 处理label用
        self.pixel_means = cfg.pixel_means                  #均值  图像预处理用
        self.num_class = cfg.num_class                      #输出个数 
        self.cfg = cfg                                      #其他一些参数
        self.bbox_extend_factor = cfg.bbox_extend_factor    #人体box外扩比例
        if train:
            self.scale_factor = cfg.scale_factor            #图像缩放比例范围
            self.rot_factor = cfg.rot_factor                #图像旋转最大度
            self.symmetry = cfg.symmetry                    #对称点对,用于翻转
        with open(cfg.gt_path) as anno_file:   
            self.anno = json.load(anno_file)          #加载label文件,将label信息保存到名为anno的list中

  3)_len_(self)

 def __len__(self):
    return len(self.anno)   #样本个数

  4)_getitem_(self, index)

    def __getitem__(self, index):
    	#1 获取样本信息
        a = self.anno[index]        #单个样本,是一个dict,包含当前样本基本信息
        image_name = a['imgInfo']['img_paths']   #当前样本,对应的图像名
        img_path = os.path.join(self.img_folder, image_name)   #图像路径
        if self.is_train:
        	#label点shape转换,输入:51×1格式     需要:17×3格式
        	#[x,y,valid],valid表示当前点分类
        	# COCO visible: 0-no label, 1-label + invisible, 2-label + visible
            points = np.array(a['unit']['keypoints']).reshape(self.num_class, 3).astype(np.float32) 
        gt_bbox = a['unit']['GT_bbox']     #人的外接矩形
        
        #2 读取图像 
        image = scipy.misc.imread(img_path, mode='RGB') 
        
        #3 截图图像,训练时,同等变换points,
        if self.is_train:
            image, points, details = self.augmentationCropImage(image, gt_bbox, points)
        else:
            image, details = self.augmentationCropImage(image, gt_bbox)

		#4 训练时,数据增强
        if self.is_train:
            image, points = self.data_augmentation(image, points, a['operation'])    #数据增强
            img = im_to_torch(image)  # CxHxW     #转化为torch格式  HxWxC ==>  CxHxW
            
            # Color dithering  抖动增强
            img[0, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[1, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
            img[2, :, :].mul_(random.uniform(0.8, 1.2)).clamp_(0, 1)
			
			#label归一化,因为CPN网络做了4次下采样,所以CPN输出是输入的1/4
            points[:, :2] //= 4 # output size is 1/4 input size
            pts = torch.Tensor(points)  #交给torch变量
        else:
            img = im_to_torch(image)
		
        img = color_normalize(img, self.pixel_means)  #减去均值
		
		#5 训练时,根据points生成target热力图(网络输出的是热力图)
        if self.is_train:
            target15 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
            target11 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
            target9 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
            target7 = np.zeros((self.num_class, self.out_res[0], self.out_res[1]))
            for i in range(self.num_class):
                if pts[i, 2] > 0: # COCO visible: 0-no label, 1-label + invisible, 2-label + visible
                    target15[i] = generate_heatmap(target15[i], pts[i], self.cfg.gk15)
                    target11[i] = generate_heatmap(target11[i], pts[i], self.cfg.gk11)
                    target9[i] = generate_heatmap(target9[i], pts[i], self.cfg.gk9)
                    target7[i] = generate_heatmap(target7[i], pts[i], self.cfg.gk7)
                    
            targets = [torch.Tensor(target15), torch.Tensor(target11), torch.Tensor(target9), torch.Tensor(target7)]
            valid = pts[:, 2]  #点分类信息
		
		#6 训练样本信息
        meta = {'index' : index, 'imgID' : a['imgInfo']['imgID'], 
        'GT_bbox' : np.array([gt_bbox[0], gt_bbox[1], gt_bbox[2], gt_bbox[3]]), 
        'img_path' : img_path, 'augmentation_details' : details}

		#7 返回数据
        if self.is_train:
            return img, targets, valid, meta
        else:
            meta['det_scores'] = a['score']
            return img, meta

4.1.2 augmentation CropImage

  该函数会根据bbox信息将人区域从原图中截取出来,point也会做相应操作;其中包括边界填充,人体外接框扩张缩放,截取图像,缩放图像,截取位置信息等等

def augmentationCropImage(self, img, bbox, joints=None):  
    height, width = self.inp_res[0], self.inp_res[1]
    bbox = np.array(bbox).reshape(4, ).astype(np.float32)
    add = max(img.shape[0], img.shape[1])
    mean_value = self.pixel_means
    bimg = cv2.copyMakeBorder(img, add, add, add, add, borderType=cv2.BORDER_CONSTANT, value=mean_value.tolist())
    objcenter = np.array([(bbox[0] + bbox[2]) / 2., (bbox[1] + bbox[3]) / 2.])      
    bbox += add
    objcenter += add
    if self.is_train:
        joints[:, :2] += add
        inds = np.where(joints[:, -1] == 0)
        joints[inds, :2] = -1000000 # avoid influencing by data processing
    crop_width = (bbox[2] - bbox[0]) * (1 + self.bbox_extend_factor[0] * 2)
    crop_height = (bbox[3] - bbox[1]) * (1 + self.bbox_extend_factor[1] * 2)
    if self.is_train:
        crop_width = crop_width * (1 + 0.25)
        crop_height = crop_height * (1 + 0.25)  
    if crop_height / height > crop_width / width:
        crop_size = crop_height
        min_shape = height
    else:
        crop_size = crop_width
        min_shape = width  

    crop_size = min(crop_size, objcenter[0] / width * min_shape * 2. - 1.)
    crop_size = min(crop_size, (bimg.shape[1] - objcenter[0]) / width * min_shape * 2. - 1)
    crop_size = min(crop_size, objcenter[1] / height * min_shape * 2. - 1.)
    crop_size = min(crop_size, (bimg.shape[0] - objcenter[1]) / height * min_shape * 2. - 1)

    min_x = int(objcenter[0] - crop_size / 2. / min_shape * width)
    max_x = int(objcenter[0] + crop_size / 2. / min_shape * width)
    min_y = int(objcenter[1] - crop_size / 2. / min_shape * height)
    max_y = int(objcenter[1] + crop_size / 2. / min_shape * height)                               

    x_ratio = float(width) / (max_x - min_x)
    y_ratio = float(height) / (max_y - min_y)

    if self.is_train:
        joints[:, 0] = joints[:, 0] - min_x
        joints[:, 1] = joints[:, 1] - min_y

        joints[:, 0] *= x_ratio
        joints[:, 1] *= y_ratio
        label = joints[:, :2].copy()
        valid = joints[:, 2].copy()

    img = cv2.resize(bimg[min_y:max_y, min_x:max_x, :], (width, height))  
    details = np.asarray([min_x - add, min_y - add, max_x - add, max_y - add]).astype(np.float)

    if self.is_train:
        return img, joints, details
    else:
        return img, details

4.1.3 data_augmentation

  该函数将对训练样本进行数据增强操作,这里包括随机缩放,随机翻转,随机旋转等

    def data_augmentation(self, img, label, operation):
        height, width = img.shape[0], img.shape[1]
        center = (width / 2., height / 2.)
        n = label.shape[0]
        affrat = random.uniform(self.scale_factor[0], self.scale_factor[1])
        
        halfl_w = min(width - center[0], (width - center[0]) / 1.25 * affrat)
        halfl_h = min(height - center[1], (height - center[1]) / 1.25 * affrat)
        img = skimage.transform.resize(img[int(center[1] - halfl_h): int(center[1] + halfl_h + 1),
                             int(center[0] - halfl_w): int(center[0] + halfl_w + 1)], (height, width))
        for i in range(n):
            label[i][0] = (label[i][0] - center[0]) / halfl_w * (width - center[0]) + center[0]
            label[i][1] = (label[i][1] - center[1]) / halfl_h * (height - center[1]) + center[1]
            label[i][2] *= (
            (label[i][0] >= 0) & (label[i][0] < width) & (label[i][1] >= 0) & (label[i][1] < height))

        # flip augmentation
        if operation == 1:
            img = cv2.flip(img, 1)
            cod = []
            allc = []
            for i in range(n):
                x, y = label[i][0], label[i][1]
                if x >= 0:
                    x = width - 1 - x
                cod.append((x, y, label[i][2]))
            # **** the joint index depends on the dataset ****    
            for (q, w) in self.symmetry:
                cod[q], cod[w] = cod[w], cod[q]
            for i in range(n):
                allc.append(cod[i][0])
                allc.append(cod[i][1])
                allc.append(cod[i][2])
            label = np.array(allc).reshape(n, 3)

        # rotated augmentation
        if operation > 1:      
            angle = random.uniform(0, self.rot_factor)
            if random.randint(0, 1):
                angle *= -1
            rotMat = cv2.getRotationMatrix2D(center, angle, 1.0)
            img = cv2.warpAffine(img, rotMat, (width, height))
            
            allc = []
            for i in range(n):
                x, y = label[i][0], label[i][1]
                v = label[i][2]
                coor = np.array([x, y])
                if x >= 0 and y >= 0:
                    R = rotMat[:, : 2]
                    W = np.array([rotMat[0][2], rotMat[1][2]])
                    coor = np.dot(R, coor) + W
                allc.append(int(coor[0]))
                allc.append(int(coor[1]))
                v *= ((coor[0] >= 0) & (coor[0] < width) & (coor[1] >= 0) & (coor[1] < height))
                allc.append(int(v))
            label = np.array(allc).reshape(n, 3).astype(np.int)
        return img, label

4-2 训练循环加载

  1)实例化数据加载器

train_loader = torch.utils.data.DataLoader(
    MscocoMulti(cfg),
    batch_size=cfg.batch_size*args.num_gpus, shuffle=True,
    num_workers=args.workers, pin_memory=True) 

  2)epoch循环

 for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 

        # train for one epoch
        train_loss = train(train_loader, model, [criterion1, criterion2], optimizer)
        
        print('train_loss: ',train_loss)

  3)batch循环

#调用上面的__getitem__(self, index),加载数据
 for i, (inputs, targets, valid, meta) in enumerate(train_loader):     
 
    input_var = torch.autograd.Variable(inputs.cuda())
    
    target15, target11, target9, target7 = targets
    refine_target_var = torch.autograd.Variable(target7.cuda(async=True))
    valid_var = torch.autograd.Variable(valid.cuda(async=True))

    # compute output
    global_outputs, refine_output = model(input_var)
    score_map = refine_output.data.cpu()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章