[Code] Using CenterNet (continued): parts 5, 6, and 7 in detail (Part 5)

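Continuing from the previous post, this series explains parts 5, 6, and 7 in detail. This post covers part 5: the model from construction through testing, and the detailed path the data takes from images to output and dets.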

I. Review

Part 5 feeds the images through the network at test time and produces the output:

      output, dets, forward_time = self.process(images, return_time=True)

process is defined in ctdet.py:

  def process(self, images, return_time=False):
    with torch.no_grad():
      output = self.model(images)[-1]
      hm = output['hm'].sigmoid_()
      wh = output['wh']
      reg = output['reg'] if self.opt.reg_offset else None
      if self.opt.flip_test:
        hm = (hm[0:1] + flip_tensor(hm[1:2])) / 2
        wh = (wh[0:1] + flip_tensor(wh[1:2])) / 2
        reg = reg[0:1] if reg is not None else None
      torch.cuda.synchronize()
      forward_time = time.time()
      dets = ctdet_decode(hm, wh, reg=reg, K=self.opt.K)
      
    if return_time:
      return output, dets, forward_time
    else:
      return output, dets
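The flip_test branch is easier to follow on a toy example. The sketch below uses plain Python lists instead of tensors, and the helper names flip_width and average_with_flip are made up for illustration (they are not part of CenterNet). It assumes the batch holds the original image followed by its horizontally mirrored copy, so the second prediction is flipped back along the width axis before the two are averaged.

```python
def flip_width(feature_map):
    """Mirror a 2-D map (list of rows) along its width axis."""
    return [list(reversed(row)) for row in feature_map]


def average_with_flip(pred_original, pred_mirrored):
    """Average the prediction for an image with the un-mirrored prediction
    for its horizontally flipped copy."""
    restored = flip_width(pred_mirrored)
    return [
        [(a + b) / 2 for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(pred_original, restored)
    ]


# Toy 2x3 heatmaps: the peak sits at column 0 in the original image,
# so it shows up at column 2 in the prediction for the mirrored image.
hm_original = [[0.8, 0.1, 0.0],
               [0.2, 0.0, 0.0]]
hm_mirrored = [[0.0, 0.1, 0.6],
               [0.0, 0.0, 0.4]]

hm_averaged = average_with_flip(hm_original, hm_mirrored)
print(hm_averaged)
```

Note that in process only hm and wh are averaged this way; reg simply keeps the prediction for the unflipped image.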

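First, images is passed through the model, which yields output. output has three parts:

{'hm': 1*80*128*128,
 'reg': 1*2*128*128,
 'wh': 1*2*128*128}

As the shapes show, only hm (the heatmap) depends on the class (80 classes); reg (the offset: x_off & y_off) and wh (width & height) are class-agnostic.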
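The shapes above follow from the input size and the network stride. As a quick check, the sketch below recomputes them in plain Python. The function name head_output_shapes is hypothetical, and the 512*512 input, the heads dictionary {'hm': 80, 'wh': 2, 'reg': 2}, and down_ratio=4 are assumptions read off the shapes and code quoted in this post.

```python
def head_output_shapes(heads, input_size=512, down_ratio=4, batch=1):
    """Return the expected (N, C, H, W) shape of every output head."""
    out_size = input_size // down_ratio
    return {name: (batch, channels, out_size, out_size)
            for name, channels in heads.items()}


heads = {'hm': 80, 'wh': 2, 'reg': 2}  # assumed: 80 classes, 2-channel wh/reg
shapes = head_output_shapes(heads)
for name, shape in shapes.items():
    print(name, shape)
```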

Next, ctdet_decode decodes these into dets, a 1*100*6 tensor.

Finally, output, dets, and forward_time are returned.

II. Details

There are two parts. The first is passing images through the model to get output.

The second is decoding with ctdet_decode.

1. self.model(images)[-1]

1. Building self.model

In BaseDetector:

    self.model = create_model(opt.arch, opt.heads, opt.head_conv)
    self.model = load_model(self.model, opt.load_model) 

Both functions come from models.model.

1. create_model

The two main lines are:

  get_model = _model_factory[arch]
  model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)

arch is used to look up get_model. In the demo this resolves to the get_pose_net function of pose_dla_dcn in networks, defined as:

def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
  model = DLASeg('dla{}'.format(num_layers), heads,
                 pretrained=True,
                 down_ratio=down_ratio,
                 final_kernel=1,
                 last_level=5,
                 head_conv=head_conv)
  return model

This is a DLA network (DLANet), possibly based on the architecture described here: https://blog.csdn.net/wuyubinbin/article/details/80622762
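How arch selects a builder can be sketched without any deep learning code. In the minimal sketch below, the registry contents, build_dla, and the '<name>_<num_layers>' string format such as 'dla_34' are assumptions made for illustration; only the _model_factory[arch] lookup and the get_model(num_layers=..., heads=..., head_conv=...) call come from the two lines quoted above.

```python
def build_dla(num_layers, heads, head_conv):
    """Hypothetical stand-in for get_pose_net; returns a description only."""
    return {'backbone': 'dla{}'.format(num_layers),
            'heads': heads,
            'head_conv': head_conv}


# Hypothetical registry mapping an architecture name to its builder.
_model_factory = {'dla': build_dla}


def create_model(arch, heads, head_conv):
    # Assumed format: '<name>_<num_layers>', e.g. 'dla_34'.
    name, _, layers = arch.partition('_')
    num_layers = int(layers) if layers else 0
    get_model = _model_factory[name]
    return get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)


model = create_model('dla_34', {'hm': 80, 'wh': 2, 'reg': 2}, 256)
print(model['backbone'])
```

A dictionary of builders keeps create_model independent of any particular backbone: supporting another network only means registering one more entry.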

2. load_model, which loads the pretrained model (still to be studied):

def load_model(model, model_path, optimizer=None, resume=False, 
               lr=None, lr_step=None):
  start_epoch = 0
  checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
  print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
  state_dict_ = checkpoint['state_dict']
  state_dict = {}
  
  # convert data_parallal to model
  for k in state_dict_:
    if k.startswith('module') and not k.startswith('module_list'):
      state_dict[k[7:]] = state_dict_[k]
    else:
      state_dict[k] = state_dict_[k]
  model_state_dict = model.state_dict()

  # check loaded parameters and created model parameters
  for k in state_dict:
    if k in model_state_dict:
      if state_dict[k].shape != model_state_dict[k].shape:
        print('Skip loading parameter {}, required shape{}, '\
              'loaded shape{}.'.format(
          k, model_state_dict[k].shape, state_dict[k].shape))
        state_dict[k] = model_state_dict[k]
    else:
      print('Drop parameter {}.'.format(k))
  for k in model_state_dict:
    if not (k in state_dict):
      print('No param {}.'.format(k))
      state_dict[k] = model_state_dict[k]
  model.load_state_dict(state_dict, strict=False)

  # resume optimizer parameters
  if optimizer is not None and resume:
    if 'optimizer' in checkpoint:
      optimizer.load_state_dict(checkpoint['optimizer'])
      start_epoch = checkpoint['epoch']
      start_lr = lr
      for step in lr_step:
        if start_epoch >= step:
          start_lr *= 0.1
      for param_group in optimizer.param_groups:
        param_group['lr'] = start_lr
      print('Resumed optimizer with start lr', start_lr)
    else:
      print('No optimizer parameters in checkpoint.')
  if optimizer is not None:
    return model, optimizer, start_epoch
  else:
    return model
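Two steps in load_model are easy to check in isolation: stripping the 'module.' prefix that nn.DataParallel adds to parameter names, and recomputing the learning rate when resuming. The sketch below reproduces both with plain dictionaries and numbers. strip_module_prefix and resumed_lr are illustrative names, and the sample keys and schedule are made up.

```python
def strip_module_prefix(state_dict):
    """Drop the leading 'module.' (7 characters) from DataParallel keys."""
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith('module') and not key.startswith('module_list'):
            cleaned[key[7:]] = value
        else:
            cleaned[key] = value
    return cleaned


def resumed_lr(base_lr, start_epoch, lr_step):
    """Multiply the rate by 0.1 for every schedule step already passed."""
    lr = base_lr
    for step in lr_step:
        if start_epoch >= step:
            lr *= 0.1
    return lr


# Made-up keys; the values would be tensors in a real checkpoint.
checkpoint_state = {'module.base.weight': 1, 'module_list.0.bias': 2, 'hm.bias': 3}
print(strip_module_prefix(checkpoint_state))
print(resumed_lr(1e-3, start_epoch=100, lr_step=[90, 120]))
```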

2. The model's forward()

    def forward(self, x):
        # x = 1*3*512*512
        x = self.base(x)
        # x is a list of six elements = 1 * [16*512*512, 32*256*256, 64*128*128, 128*64*64, 256*32*32, 512*16*16]

        x = self.dla_up(x)
        y = []
        for i in range(self.last_level - self.first_level):
            y.append(x[i].clone())
        # y = 1 * [64*128*128, 128*64*64, 256*32*32]
        self.ida_up(y, 0, len(y))
        # y = 1 * [64*128*128, 64*128*128, 64*128*128]
        z = {}
        for head in self.heads:
            z[head] = self.__getattr__(head)(y[-1])
        # z = {'hm' : 1*80*128*128,
        #      'reg' : 1*2*128*128,
        #      'wh' : 1*2*128*128}
        return [z]
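The sizes in the comments follow a simple rule: level i of the backbone has stride 2**i. The sketch below recomputes the six feature-map shapes for a 512*512 input. The channel list is copied from the comment in forward(), base_feature_shapes is a name made up for this example, and first_level=2 is an assumption derived from down_ratio=4.

```python
def base_feature_shapes(input_size, channels):
    """Shape (C, H, W) of each backbone level, assuming stride 2**level."""
    return [(c, input_size >> level, input_size >> level)
            for level, c in enumerate(channels)]


channels = [16, 32, 64, 128, 256, 512]  # taken from the comment in forward()
shapes = base_feature_shapes(512, channels)
for level, shape in enumerate(shapes):
    print(level, shape)

# With first_level=2 (assumed from down_ratio=4), the heads read level 2.
first_level = 2
print(shapes[first_level])
```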

2. ctdet_decode

ctdet_decode lives in models.decode. The detections it produces are the concatenation of bboxes, scores, and clses:
    detections = torch.cat([bboxes, scores, clses], dim=2)

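Here bboxes is a 1*100*4 FloatTensor in (top-left, bottom-right) corner form. scores is a 1*100*1 FloatTensor with values in [0, 1], sorted in descending order. clses is also a 1*100*1 tensor; its values are all integers and denote the class. For the decoding procedure itself, see the earlier post: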

https://mp.csdn.net/postedit/91955759
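To make the layout of detections concrete, here is a simplified pure-Python sketch of center-point decoding. It is not the code in models.decode: toy_ctdet_decode is a made-up name, the maps are tiny nested lists rather than tensors, and the sketch assumes the usual steps of keeping 3x3 local maxima of the heatmap, taking the top K by score, refining each center with reg, and expanding it by half of wh in each direction.

```python
def toy_ctdet_decode(hm, wh, reg, K=2):
    """Decode one image. hm: [C][H][W] scores in [0, 1]; wh, reg: [2][H][W]."""
    num_classes, height, width = len(hm), len(hm[0]), len(hm[0][0])
    peaks = []
    for c in range(num_classes):
        for y in range(height):
            for x in range(width):
                score = hm[c][y][x]
                # Keep a cell only if it is the maximum of its 3x3 neighbourhood.
                neighbours = [hm[c][j][i]
                              for j in range(max(0, y - 1), min(height, y + 2))
                              for i in range(max(0, x - 1), min(width, x + 2))]
                if score == max(neighbours):
                    peaks.append((score, c, y, x))
    peaks.sort(key=lambda p: p[0], reverse=True)  # descending score

    detections = []
    for score, c, y, x in peaks[:K]:
        cx, cy = x + reg[0][y][x], y + reg[1][y][x]  # refined centre
        w, h = wh[0][y][x], wh[1][y][x]
        # Row layout: [x1, y1, x2, y2, score, class]
        detections.append([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2,
                           score, c])
    return detections


# Toy inputs: 2 classes on a 3x3 map, constant size and offset everywhere.
hm = [[[0.1, 0.2, 0.1], [0.2, 0.9, 0.3], [0.1, 0.3, 0.2]],
      [[0.0, 0.1, 0.6], [0.0, 0.1, 0.2], [0.0, 0.0, 0.1]]]
wh = [[[2.0] * 3 for _ in range(3)], [[1.0] * 3 for _ in range(3)]]
reg = [[[0.5] * 3 for _ in range(3)], [[0.25] * 3 for _ in range(3)]]

dets = toy_ctdet_decode(hm, wh, reg, K=2)
for det in dets:
    print(det)
```

Each row has the same six fields as one entry of dets. The box coordinates here are in feature-map cells, not input-image pixels.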

 
