Hourglass 的 PyTorch 實現

主要分為以下幾部分：

1、數據集讀取

2、hg-model

3、training

4、代碼主要結合了 GitHub 上幾位大佬的代碼：@bearpaw 和 @roytseng-tw 的訓練代碼，以及 @anibali 的 evaluation 代碼。這幾位大佬的代碼與原作者在 Torch7 上用 Lua 編寫的源代碼基本沒有出入，是很好的復現。

5、訓練集、驗證集、測試集的劃分同樣採用了 hourglass 原作者 @umich-vl 提供的版本。

6、同時我也會在 GitHub 上放出 Caffe 版本的 hourglass 實現，這個版本主要來自 RMPE 這篇論文的 GitHub 倉庫。

7、我目前的訓練結果在 MPII 驗證集上只能達到 89.3（閾值 0.5）。

一、數據讀取 

        1、數據增廣

            這裏主要涉及 crop、scale、flip、rotate 這幾種操作。

""" Random """
def randn():
    """Return one draw from the standard normal distribution N(0, 1)."""
    mean, std_dev = 0, 1
    return random.gauss(mean, std_dev)

def rand():
    """Return one uniform draw from the half-open interval [0, 1)."""
    sample = random.random()
    return sample

def rnd(x):
    """Return Gaussian noise with standard deviation `x`, clamped to [-2x, 2x].

    For non-negative `x` this keeps the draw within two standard deviations.
    Mirrors the random helper of the umich hourglass MPII training code.
    """
    noise = randn() * x
    limit = 2 * x
    return max(-limit, min(limit, noise))


""" Visualization """
def show_sample(img, label):  # FIXME: color blending is not right, diff color for each joint
    """Blend every joint heatmap of `label` onto `img` for quick inspection.

    img: channel-first image; assumed to be (3, H, W) — TODO confirm.
    label: per-joint heatmaps stacked along axis 0; each one is resized to
        (H, W) before blending.

    Returns an (H, W, 4) array: the image at half intensity in the first
    three channels and a fourth channel that starts at 1. 0.5 times every
    resized heatmap is added to all four channels.
    """
    spatial = img.shape[1:3]
    base = np.ones((4,) + spatial)
    canvas = base.copy()
    canvas[:3] = img * 0.5
    for joint in range(label.shape[0]):
        heatmap = label[joint]
        canvas += 0.5 * base * sktf.resize(heatmap, spatial, preserve_range=True)
    return canvas.transpose(1, 2, 0)


""" Label """
def create_label(imsize, pt, sigma, distro_type='Gaussian'):
    """Render the target heatmap for a single joint.

    Args:
        imsize: (height, width) of the heatmap to produce.
        pt: (x, y) joint location in heatmap pixel coordinates.
        sigma: spread of the blob. The blob is drawn on a (6 * sigma + 1)
            pixel square window centred on `pt`. An integer sigma is assumed;
            non-integer values give a window that is not exactly centred.
        distro_type: 'Gaussian' (peak value 1) or 'Cauchy'.

    Returns:
        np.ndarray of shape `imsize`. It is all zeros when the window lies
        completely outside the heatmap.

    Raises:
        ValueError: if `distro_type` is not recognised.
    """
    label = np.zeros(imsize)
    # Window bounds: `ul` is inclusive and `br` is exclusive (hence the +1),
    # so the window spans 6 * sigma + 1 pixels and the blob is symmetric
    # around pt. (np.math was removed in NumPy 2.0, so use np.floor.)
    ul = int(np.floor(pt[0] - 3 * sigma)), int(np.floor(pt[1] - 3 * sigma))
    br = int(np.floor(pt[0] + 3 * sigma)) + 1, int(np.floor(pt[1] + 3 * sigma)) + 1
    # If no part of the window is inside the heatmap, return the blank label
    if ul[0] >= imsize[1] or ul[1] >= imsize[0] or br[0] <= 0 or br[1] <= 0:
        return label

    # Generate distro
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    # Note: the original torch impl `local g = image.gaussian(size)` is
    # equivalent to `gaussian(size, sigma=0.25*size)` here.
    if distro_type == 'Gaussian':
        distro = np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    elif distro_type == 'Cauchy':
        # NOTE(review): unnormalised form inherited from the original code;
        # whether this is the intended Cauchy kernel is unverified.
        distro = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
    else:
        raise ValueError('unknown distro_type: {}'.format(distro_type))

    # Part of the window that falls inside the heatmap, in window coordinates
    distro_x = max(0, -ul[0]), min(br[0], imsize[1]) - ul[0]
    distro_y = max(0, -ul[1]), min(br[1], imsize[0]) - ul[1]
    # ... and the same region in heatmap coordinates
    label_x = max(0, ul[0]), min(br[0], imsize[1])
    label_y = max(0, ul[1]), min(br[1], imsize[0])
    label[label_y[0]:label_y[1], label_x[0]:label_x[1]] = \
        distro[distro_y[0]:distro_y[1], distro_x[0]:distro_x[1]]
    return label


""" Flip """
def fliplr_labels(labels, matchedParts, joint_dim=1, width_dim=3):
    """Mirror joint heatmaps horizontally and swap left/right joint channels.

    The default axes assume a (B, C, H, W) layout: joints on axis 1 and
    width on axis 3.
    matchedParts: iterable of (i, j) joint-index pairs to exchange.
    """
    mirrored = np.flip(labels, axis=width_dim)
    order = list(range(mirrored.shape[joint_dim]))
    for left, right in matchedParts:
        order[left], order[right] = right, left
    return np.take(mirrored, np.asarray(order, dtype=np.intp), axis=joint_dim)

def fliplr_coords(pts, width, matchedParts):
    """Mirror (x, y) joint coordinates horizontally and swap left/right joints.

    Only points with x > 0 are mirrored (to width - x). Other points are
    presumably unannotated joints and are kept as they are.
    matchedParts: iterable of (i, j) joint-index pairs to exchange.
    """
    def _mirror(point):
        x, y = point
        return (width - x, y) if x > 0 else (x, y)

    flipped = np.array(list(map(_mirror, pts)))
    order = np.arange(flipped.shape[0])
    for left, right in matchedParts:
        order[left], order[right] = right, left
    return flipped[order]


""" Transform, Crop """
def get_transform(center, scale, rot, res, invert=False):
    """Build the 3x3 homogeneous matrix from original-image to crop coordinates.

    center: (x, y) of the crop centre in the original image.
    scale: crop side length in units of 200 px (side = 200 * scale).
    rot: rotation in degrees; a non-zero value rotates about the crop centre.
    res: (height, width) of the output crop.
    invert: if True, return the inverse mapping (crop -> original).
    """
    box = 200 * scale
    out_h, out_w = res[0], res[1]
    # Scale the crop box to the output size and move its centre to the
    # output centre.
    mat = np.array([
        [out_w / box, 0, out_w * (-center[0] / box + .5)],
        [0, out_h / box, out_h * (-center[1] / box + .5)],
        [0, 0, 1],
    ], dtype=float)
    if rot != 0:
        # Negated to match the rotation direction used when cropping
        theta = -rot * np.pi / 180
        sn, cs = np.sin(theta), np.cos(theta)
        rotation = np.array([
            [cs, -sn, 0],
            [sn, cs, 0],
            [0, 0, 1],
        ], dtype=float)
        # Rotate about the output centre: shift it to the origin, rotate,
        # then shift back.
        to_origin = np.eye(3)
        to_origin[0, 2] = -res[1] / 2
        to_origin[1, 2] = -res[0] / 2
        from_origin = np.eye(3)
        from_origin[:2, 2] = -to_origin[:2, 2]
        mat = from_origin.dot(rotation.dot(to_origin.dot(mat)))
    return np.linalg.inv(mat) if invert else mat

def transform(pts, center, scale, rot, res, invert=False):
    """Map points between original-image and square-crop pixel coordinates.

    Args:
        pts: a single point (x, y), or a 2 x n array with one point per column.
        center, scale, rot: crop parameters, forwarded to `get_transform`.
        res: side length of the square crop.
        invert: if True, map from crop coordinates back to the original image.

    Returns:
        Integer pixel coordinates: shape (2,) for a single point, (2, n)
        otherwise.
    """
    t = get_transform(center, scale, rot, [res, res], invert)
    pts = np.array(pts)
    assert pts.shape[0] == 2, pts.shape
    # Append the homogeneous coordinate so the affine matrix can be applied
    if pts.ndim == 1:
        pts = np.array([pts[0], pts[1], 1])
    else:
        pts = np.concatenate([pts, np.ones((1, pts.shape[1]))], axis=0)
    new_pt = np.dot(t, pts)
    # Round down rather than truncate. `astype(int)` alone rounds toward
    # zero, which maps e.g. -0.5 to pixel 0 instead of pixel -1 and shifts
    # every negative, non-integer coordinate one pixel inward.
    return np.floor(new_pt[:2]).astype(int)

def crop(img, center, scale, rot, res):
    '''Cut a square patch out of `img`, optionally rotated, resized to res x res.

    img: image array, (H, W) or (H, W, C).
    center: (x, y) centre of the patch in `img` pixel coordinates.
    scale: patch side length in units of 200 px (see get_transform).
    rot: in degrees
    res: single value of targeted output image resolution

    Returns an array of shape (res, res) or (res, res, C). Regions of the
    patch outside the image are zero.
    '''
    ht, wd = img.shape[0], img.shape[1]
    # Preprocessing for efficient cropping: when the patch is at least twice
    # the output size, downsample the whole image first, so the crop below is
    # taken directly at the output resolution.
    sf = scale * 200.0 / res
    if sf < 2:
        sf = 1
    else:
        # np.math was removed in NumPy 2.0; use np.floor instead
        new_size = int(np.floor(max(ht, wd) / sf))
        new_ht = int(np.floor(ht / sf))
        new_wd = int(np.floor(wd / sf))
        if new_size < 2:
            # Zoomed out so much that the image is now a single pixel or less.
            # np.zeros takes the shape as ONE tuple argument.
            return np.zeros((res, res)) if img.ndim == 2 \
                else np.zeros((res, res, img.shape[2]))
        img = sktf.resize(img, [new_ht, new_wd], preserve_range=True)
        ht, wd = img.shape[0], img.shape[1]

    # Calculate upper-left and bottom-right corners of the crop region
    center = center / sf
    scale = scale / sf
    ul = transform([0, 0], center, scale, 0, res, invert=True)
    br = transform([res, res], center, scale, 0, res, invert=True)
    if sf >= 2:
        # Force an exact res x res window (corner rounding can be off by one)
        br = ul + res

    # Padding so that when rotated proper amount of context is included
    pad = int(np.ceil(np.linalg.norm(br - ul) / 2 - (br[0] - ul[0]) / 2))
    if rot != 0:
        ul = ul - pad
        br = br + pad

    # Define the range of pixels to take from the old image
    old_x = max(0, ul[0]), min(br[0], wd)
    old_y = max(0, ul[1]), min(br[1], ht)
    # And where to put them in the new image
    new_x = max(0, -ul[0]), min(br[0], wd) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], ht) - ul[1]

    # Initialize new image and copy pixels over
    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if img.ndim > 2:
        new_shape.append(img.shape[2])
    new_img = np.zeros(new_shape)
    # Copy only when the window actually overlaps the image. Otherwise the
    # bounds above are inverted or negative, and the assignment would raise a
    # shape mismatch or could copy the wrong pixels. The patch then stays zero.
    if old_x[1] > old_x[0] and old_y[1] > old_y[0]:
        new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = \
            img[old_y[0]:old_y[1], old_x[0]:old_x[1]]

    if rot != 0:
        # Rotate the image and remove padded area
        new_img = sktf.rotate(new_img, rot, preserve_range=True)
        if pad > 0:
            # `x[0:-0]` is empty, so trim only when there is padding
            new_img = new_img[pad:-pad, pad:-pad]

    if sf < 2:
        new_img = sktf.resize(new_img, [res, res], preserve_range=True)

    return new_img

        2、針對數據集讀取數據 batch

            結合這個腳本和上面的數據增廣腳本，基本上就完成了全部的數據操作。

class MPII_Dataset(torch.utils.data.Dataset):
    """MPII human-pose dataset returning network inputs and target heatmaps.

    Annotations are read from `{data_root}/mpii/{split}.h5` (datasets
    'imgname', 'part', 'center', 'scale'). Images are read from
    `{data_root}/mpii/images/`.

    Args:
        data_root: dataset root directory.
        split: annotation file name. 'train' enables random scale, rotation,
            horizontal flip and colour jitter.
        inp_res: side of the square input crop.
        out_res: side of the square target heatmaps.
        sigma: heatmap blob spread passed to `create_label`.
        scale_factor, rot_factor: augmentation strength (see `rnd`).
        return_meta: also return [pts, center, scale, rotation].
        small_image: also return the input crop resized to out_res.

    Each item is a tuple: (image (3, inp_res, inp_res) float32,
    heatmaps (16, out_res, out_res) float32[, small image (3, out_res,
    out_res) float32][, meta]).
    """
    def __init__(self, data_root, split,
                 inp_res=256, out_res=64, sigma=1,
                 scale_factor=0.25, rot_factor=30, return_meta=False, small_image=True):
        self.data_root = data_root
        self.split = split
        self.inp_res = inp_res
        self.out_res = out_res
        self.sigma = sigma
        self.scale_factor = scale_factor
        self.rot_factor = rot_factor
        self.return_meta = return_meta
        self.small_image = small_image

        self.nJoints = 16
        self.accIdxs = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]  # joint idxs for accuracy calculation
        # Left/right joint index pairs that are swapped when mirroring
        self.flipRef = [[0, 5],   [1, 4],   [2, 3],   # noqa
                        [10, 15], [11, 14], [12, 13]]

        # Load all annotations into memory. `with` closes the file even if
        # one of the datasets is missing.
        self.annot = {}
        tags = ['imgname', 'part', 'center', 'scale']
        with h5py.File('{}/mpii/{}.h5'.format(data_root, split), 'r') as f:
            for tag in tags:
                self.annot[tag] = np.asarray(f[tag]).copy()

    def _getPartInfo(self, index):
        """Return copies of (joint coords, center, scale) for sample `index`."""
        # get a COPY so augmentation never mutates the cached annotations
        pts = self.annot['part'][index].copy()
        c = self.annot['center'][index].copy()
        s = self.annot['scale'][index].copy()

        # Small adjustment so cropping is less likely to take feet out
        c[1] = c[1] + 15 * s
        s = s * 1.25
        return pts, c, s

    @staticmethod
    def _toRGB(im):
        """Return `im` as an (H, W, 3) array.

        A grayscale image (2-D, or with fewer than 3 channels) has its first
        channel replicated; any channels beyond the third (e.g. alpha) are
        dropped.
        """
        if im.ndim == 2:
            im = im[:, :, np.newaxis]
        if im.shape[2] < 3:
            im = np.repeat(im[:, :, :1], 3, axis=2)
        return im[:, :, :3]

    def _loadImage(self, index):
        """Load image `index` as a float RGB array of shape (H, W, 3)."""
        impath = os.path.join(self.data_root, 'mpii/images', self.annot['imgname'][index].decode('utf-8'))
        im = skim.img_as_float(skio.imread(impath))
        # Normalise channels up front: the flip and colour jitter below
        # index/broadcast over exactly three channels.
        return self._toRGB(im)

    def __getitem__(self, index):
        im = self._loadImage(index)
        pts, c, s = self._getPartInfo(index)
        r = 0
        if self.split == 'train':
            # scale and rotation
            s = s * (2 ** rnd(self.scale_factor))
            r = 0 if rand() < 0.6 else rnd(self.rot_factor)
            # flip LR
            if rand() < 0.5:
                im = im[:, ::-1, :]
                pts = fliplr_coords(pts, width=im.shape[1], matchedParts=self.flipRef)
                c[0] = im.shape[1] - c[0]  # flip center point also
            # Color jitter: scale each RGB channel by its own random factor
            im = np.clip(im * np.random.uniform(0.6, 1.4, size=3), 0, 1)
        # Prepare image (crop preserves the 3 channels from _loadImage)
        im = crop(im, c, s, r, self.inp_res)
        if self.small_image:
            # small size image
            im_s = sktf.resize(im, [self.out_res, self.out_res], preserve_range=True)

        # (h, w, c) to (c, h, w)
        im = np.transpose(im, [2, 0, 1])

        # Prepare label: one heatmap per joint at out_res resolution
        labels = np.zeros((self.nJoints, self.out_res, self.out_res))
        new_pts = transform(pts.T, c, s, r, self.out_res).T
        for i in range(self.nJoints):
            # joints with non-positive x are treated as unannotated
            if pts[i, 0] > 0:
                labels[i] = create_label(
                    labels.shape[1:],
                    new_pts[i],
                    self.sigma)

        ret_list = [im.astype(np.float32), labels.astype(np.float32)]
        if self.small_image:
            # float32 like the other image outputs
            ret_list.append(np.transpose(im_s, [2, 0, 1]).astype(np.float32))
        if self.return_meta:
            meta = [pts, c, s, r]
            ret_list.append(meta)
        return tuple(ret_list)

    def __len__(self):
        return len(self.annot['imgname'])

二、模型代碼

        1、首先定義殘差網絡的基本模塊

class HgResBlock(nn.Module):
    ''' Hourglass residual block (pre-activation bottleneck).

    Main path: BN-ReLU-conv1x1 -> BN-ReLU-conv3x3 -> BN-ReLU-conv1x1.
    Skip path: identity when input and output shapes match, otherwise a
    1x1 convolution.

    Args:
        inplanes: input channels.
        outplanes: output channels (the bottleneck uses outplanes // 2).
        stride: spatial stride. It is applied once, in the 3x3 convolution,
            and in the skip convolution, so both paths produce the same size.
    '''
    def __init__(self, inplanes, outplanes, stride=1):
        super().__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.stride = stride
        midplanes = outplanes // 2
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, midplanes, 1)  # bias=False
        self.bn2 = nn.BatchNorm2d(midplanes)
        self.conv2 = nn.Conv2d(midplanes, midplanes, 3, stride, 1)
        self.bn3 = nn.BatchNorm2d(midplanes)
        self.conv3 = nn.Conv2d(midplanes, outplanes, 1)  # bias=False
        self.relu = nn.ReLU(inplace=True)
        # The skip path needs a projection whenever it must change either the
        # channel count or the spatial size.
        self.use_conv_skip = inplanes != outplanes or stride != 1
        if self.use_conv_skip:
            self.conv_skip = nn.Conv2d(inplanes, outplanes, 1, stride)

    def forward(self, x):
        residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        if self.use_conv_skip:
            residual = self.conv_skip(residual)
        out += residual
        return out

            2、定義 hourglass 的基本結構

class Hourglass(nn.Module):
    """A single hourglass: recursive pool/process/upsample with skip branches.

    Args:
        depth: number of downsampling levels (recursion depth).
        nFeat: channel count, constant throughout the module.
        nModules: residual blocks per branch at each level.
        resBlock: residual block factory, called as resBlock(nFeat, nFeat).

    The output has the same shape as the input.
    """
    def __init__(self, depth, nFeat, nModules, resBlock):
        super().__init__()
        self.depth = depth
        self.nFeat = nFeat
        self.nModules = nModules  # num residual modules per location
        self.resBlock = resBlock

        self.hg = self._make_hour_glass()
        self.downsample = nn.MaxPool2d(2, 2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def _make_hour_glass(self):
        """Build per-level branches: [skip, down_path, up_path(, bottom)]."""
        hg = []
        for i in range(self.depth):
            res = [self._make_residual(self.nModules) for _ in range(3)]  # skip(upper branch); down_path, up_path(lower branch)
            if i == (self.depth - 1):
                res.append(self._make_residual(self.nModules))  # extra one for the middle
            hg.append(nn.ModuleList(res))
        return nn.ModuleList(hg)

    def _make_residual(self, n):
        return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)])

    def forward(self, x):
        return self._hour_glass_forward(0, x)

    def _hour_glass_forward(self, depth_id, x):
        # Upper branch keeps the current resolution
        up1 = self.hg[depth_id][0](x)
        # Lower branch: pool, process, recurse (or the bottom block), process
        low1 = self.downsample(x)
        low1 = self.hg[depth_id][1](low1)
        if depth_id == (self.depth - 1):
            low2 = self.hg[depth_id][3](low1)
        else:
            low2 = self._hour_glass_forward(depth_id + 1, low1)
        low3 = self.hg[depth_id][2](low2)
        up2 = self.upsample(low3)
        if up2.shape[-2:] != up1.shape[-2:]:
            # Odd input size: max-pooling floored it, so a plain 2x upsample
            # comes back one pixel short. Resize to the skip branch's size.
            up2 = nn.functional.interpolate(low3, size=up1.shape[-2:], mode='nearest')
        return up1 + up2


class HourglassNet(nn.Module):
    '''Stacked hourglass model from Newell et al ECCV 2016.

    Args:
        nStacks: number of stacked hourglasses; each stack emits a score map.
        nModules: residual blocks per branch/location.
        nFeat: feature channels inside the hourglasses.
        nClasses: number of output heatmaps (joints).
        resBlock: residual block factory, called as resBlock(in, out).
        inplanes: channels of the input image.
        depth: downsampling levels inside each hourglass (default 4, the
            value previously hard-coded).

    forward(x) returns a list of nStacks score tensors (for intermediate
    supervision), each with nClasses channels at roughly 1/4 of the input
    resolution (a stride-2 conv followed by a 2x2 max-pool).
    '''
    def __init__(self, nStacks, nModules, nFeat, nClasses, resBlock=HgResBlock, inplanes=3, depth=4):
        super().__init__()
        self.nStacks = nStacks
        self.nModules = nModules
        self.nFeat = nFeat
        self.nClasses = nClasses
        self.resBlock = resBlock
        self.inplanes = inplanes
        self.depth = depth

        self._make_head()

        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(nStacks):
            hg.append(Hourglass(depth, nFeat, nModules, resBlock))
            res.append(self._make_residual(nModules))
            fc.append(self._make_fc(nFeat, nFeat))
            score.append(nn.Conv2d(nFeat, nClasses, 1))
            if i < (nStacks - 1):
                # Remap features/scores back to nFeat channels for the next stack
                fc_.append(nn.Conv2d(nFeat, nFeat, 1))
                score_.append(nn.Conv2d(nClasses, nFeat, 1))
        self.hg = nn.ModuleList(hg)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

    def _make_head(self):
        """Stem: 7x7 stride-2 conv, residual blocks and one 2x2 max-pool."""
        self.conv1 = nn.Conv2d(self.inplanes, 64, 7, 2, 3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.res1 = self.resBlock(64, 128)
        self.pool = nn.MaxPool2d(2, 2)
        self.res2 = self.resBlock(128, 128)
        self.res3 = self.resBlock(128, self.nFeat)

    def _make_residual(self, n):
        return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)])

    def _make_fc(self, inplanes, outplanes):
        """1x1 conv + BN + ReLU."""
        return nn.Sequential(
            nn.Conv2d(inplanes, outplanes, 1),
            nn.BatchNorm2d(outplanes),
            nn.ReLU(True))

    def forward(self, x):
        # head
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.res1(x)
        x = self.pool(x)
        x = self.res2(x)
        x = self.res3(x)

        out = []
        for i in range(self.nStacks):
            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)
            out.append(score)
            if i < (self.nStacks - 1):
                # Feed this stack's features and predictions into the next one
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_

        return out

三、訓練

        初始化數據和網絡

# Training data: MPII train split with scale/rotation augmentation
train_set = MPII_Dataset(
    FLAGS.dataDir, split='train',
    inp_res=FLAGS.inputRes, out_res=FLAGS.outputRes,
    scale_factor=FLAGS.scale, rot_factor=FLAGS.rotate, sigma=FLAGS.hmSigma)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=FLAGS.trainBatch, shuffle=True,
    num_workers=FLAGS.nThreads, pin_memory=True)
netHg = nn.DataParallel(HourglassNet(
    nStacks=FLAGS.nStacks, nModules=FLAGS.nModules, nFeat=FLAGS.nFeats,
    nClasses=train_set.nJoints))  # ref `nClasses` from dataset
criterion = nn.MSELoss()
# Use the same switch as the training loop `run`, which moves each batch to
# the GPU when FLAGS.cuda is set. A separate flag here could leave the model
# on the CPU while its inputs are on the GPU, or vice versa.
if FLAGS.cuda:
    torch.backends.cudnn.benchmark = True
    netHg.cuda()
    criterion.cuda()

optimHg = torch.optim.RMSprop(
    netHg.parameters(),
    lr=FLAGS.lr,
    alpha=FLAGS.alpha, eps=FLAGS.eps)

       調用網絡進行訓練

def run(epoch, iter_start=0):
    """Train `netHg` for one epoch over `train_loader`.

    Uses module-level state: netHg, optimHg, criterion, train_loader,
    train_set, sumWriter, accuracy, FLAGS and the step counter
    `global_step`, which is incremented once per batch. The loss and the
    last stack's accuracy are logged to `sumWriter` and to a tqdm bar.

    Args:
        epoch: epoch number, shown in the progress-bar description.
        iter_start: starting value of the per-epoch iteration counter.
    """
    netHg.train()

    global global_step
    pbar = tqdm.tqdm(train_loader, desc='Epoch %02d' % epoch, dynamic_ncols=True)
    pbar_info = tqdm.tqdm(bar_format='{bar}{postfix}')
    avg_acc = 0
    try:
        for it, sample in enumerate(pbar, start=iter_start):
            global_step += 1
            # The dataset may append extra items (small image, meta); only
            # the input image and the target heatmaps are used here.
            image, label = sample[0], sample[1]
            if FLAGS.cuda:
                # `async` is a reserved word since Python 3.7; PyTorch >= 0.4
                # calls this argument `non_blocking`.
                image = image.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)

            # Intermediate supervision: sum the loss over all stacks
            outputs = netHg(image)
            loss_hg = sum(criterion(out, label) for out in outputs)

            optimHg.zero_grad()
            loss_hg.backward()
            optimHg.step()

            accs = accuracy(outputs[-1].detach().cpu(), label.detach().cpu(), train_set.accIdxs)
            # Log a plain number rather than the graph-attached tensor
            loss_value = loss_hg.item()

            sumWriter.add_scalar('loss_hg', loss_value, global_step)
            sumWriter.add_scalar('acc', accs[0], global_step)
            # TODO: learning rate scheduling
            # sumWriter.add_scalar('lr', lr, global_step)

            pbar_info.set_postfix({
                'loss_hg': loss_value,
                'acc': accs[0]
            })
            pbar_info.update()
            avg_acc += accs[0] / len(train_loader)

        pbar_info.set_postfix_str('avg_acc: {}'.format(avg_acc))
    finally:
        # Close the bars even if an iteration raises
        pbar.close()
        pbar_info.close()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章