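Continuing from the previous post, which left parts five, six and seven to be explained in detail, this post covers part five: the model from creation to testing, and the detailed path of the data from images to output and dets.
I. Review
Part five feeds the images into the network for testing and produces the output: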
output, dets, forward_time = self.process(images, return_time=True)
The process method is defined in ctdet.py:
def process(self, images, return_time=False):
    with torch.no_grad():
        output = self.model(images)[-1]
        hm = output['hm'].sigmoid_()
        wh = output['wh']
        reg = output['reg'] if self.opt.reg_offset else None
        if self.opt.flip_test:
            hm = (hm[0:1] + flip_tensor(hm[1:2])) / 2
            wh = (wh[0:1] + flip_tensor(wh[1:2])) / 2
            reg = reg[0:1] if reg is not None else None
        torch.cuda.synchronize()
        forward_time = time.time()
        dets = ctdet_decode(hm, wh, reg=reg, K=self.opt.K)
    if return_time:
        return output, dets, forward_time
    else:
        return output, dets
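To make the flip_test branch concrete, here is a minimal pure-Python sketch of the averaging step on a single 2D map. It assumes flip_tensor mirrors a map along its width axis; flip_w and flip_average are hypothetical names used only for this illustration, and the real code works on 4D tensors.

```python
def flip_w(fmap):
    """Mirror a 2D map (a list of rows) left to right."""
    return [row[::-1] for row in fmap]


def flip_average(pred_orig, pred_flipped):
    """Average the prediction for the original image with the prediction
    for the mirrored image, after mirroring the latter back into place."""
    restored = flip_w(pred_flipped)
    return [
        [(a + b) / 2 for a, b in zip(row_a, row_b)]
        for row_a, row_b in zip(pred_orig, restored)
    ]


hm_orig = [[0.0, 1.0], [0.5, 0.5]]
hm_flip = [[0.75, 0.0], [0.5, 0.5]]  # prediction on the mirrored image
averaged = flip_average(hm_orig, hm_flip)
```

Note that in process only hm and wh are averaged this way; reg simply keeps the prediction for the un-flipped image.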
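First, images is passed through the model, which gives output. output has three parts:
{'hm': 1*80*128*128,
 'reg': 1*2*128*128,
 'wh': 1*2*128*128}
As the shapes show, only hm (the heatmap) depends on the class (80 classes); reg (the offset: x_off & y_off) and wh (width & height) are class-agnostic.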
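These shapes follow from the input size and the downsampling ratio. A small sketch, assuming a 512*512 input, down_ratio=4 and the head channel counts listed above (head_shapes is a hypothetical helper, not part of the repository):

```python
def head_shapes(heads, input_h=512, input_w=512, down_ratio=4, batch=1):
    """Return the expected output shape of each head.

    heads maps a head name to its channel count, e.g. 80 classes for 'hm'.
    """
    out_h, out_w = input_h // down_ratio, input_w // down_ratio
    return {name: (batch, channels, out_h, out_w)
            for name, channels in heads.items()}


shapes = head_shapes({'hm': 80, 'wh': 2, 'reg': 2})
```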
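Next, ctdet_decode decodes these maps into dets, a 1*100*6 tensor.
Finally, output, dets and forward_time are returned.
II. Details
There are two parts. The first is passing images through the model to obtain output.
The second is decoding with ctdet_decode.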
1.self.model(images)[-1]
1. Building self.model
In BaseDetector:
self.model = create_model(opt.arch, opt.heads, opt.head_conv)
self.model = load_model(self.model, opt.load_model)
Both functions come from models.model.
1. create_model
The two main lines of code are:
get_model = _model_factory[arch]
model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
As for the intermediate variables: arch is used to look up get_model. In the demo, this is the get_pose_net function from pose_dla_dcn in networks, defined as:
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
    model = DLASeg('dla{}'.format(num_layers), heads,
                   pretrained=True,
                   down_ratio=down_ratio,
                   final_kernel=1,
                   last_level=5,
                   head_conv=head_conv)
    return model
This is a DLA network (DLANet); it may be based on the architecture described here: https://blog.csdn.net/wuyubinbin/article/details/80622762
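The lookup in create_model can be pictured as a plain dictionary from architecture name to builder function. The sketch below assumes arch strings of the form 'dla_34', where the suffix is the layer count; split_arch, the stub builder and the factory contents are illustrative stand-ins, not the repository's code.

```python
def split_arch(arch):
    """Split 'dla_34' into ('dla', 34); names without '_' get 0 layers."""
    if '_' in arch:
        name, layers = arch.split('_', 1)
        return name, int(layers)
    return arch, 0


def get_pose_net_stub(num_layers, heads, head_conv=256):
    """Stand-in builder: returns a description instead of a real network."""
    return {'backbone': 'dla{}'.format(num_layers),
            'heads': heads,
            'head_conv': head_conv}


# Hypothetical factory: architecture name -> builder function.
_model_factory = {'dla': get_pose_net_stub}


def create_model(arch, heads, head_conv):
    name, num_layers = split_arch(arch)
    get_model = _model_factory[name]
    return get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)


model = create_model('dla_34', {'hm': 80, 'wh': 2, 'reg': 2}, 256)
```

Keeping the builders in a dictionary means a new backbone only needs one extra entry, and the calling code in the detector stays unchanged.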
2. load_model, which loads the pretrained model (I have not gone through it in detail yet):
def load_model(model, model_path, optimizer=None, resume=False,
               lr=None, lr_step=None):
    start_epoch = 0
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))
    state_dict_ = checkpoint['state_dict']
    state_dict = {}
    # convert DataParallel keys to plain model keys
    for k in state_dict_:
        if k.startswith('module') and not k.startswith('module_list'):
            state_dict[k[7:]] = state_dict_[k]
        else:
            state_dict[k] = state_dict_[k]
    model_state_dict = model.state_dict()
    # check loaded parameters and created model parameters
    for k in state_dict:
        if k in model_state_dict:
            if state_dict[k].shape != model_state_dict[k].shape:
                print('Skip loading parameter {}, required shape{}, '
                      'loaded shape{}.'.format(
                          k, model_state_dict[k].shape, state_dict[k].shape))
                state_dict[k] = model_state_dict[k]
        else:
            print('Drop parameter {}.'.format(k))
    for k in model_state_dict:
        if not (k in state_dict):
            print('No param {}.'.format(k))
            state_dict[k] = model_state_dict[k]
    model.load_state_dict(state_dict, strict=False)
    # resume optimizer parameters
    if optimizer is not None and resume:
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            start_lr = lr
            for step in lr_step:
                if start_epoch >= step:
                    start_lr *= 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = start_lr
            print('Resumed optimizer with start lr', start_lr)
        else:
            print('No optimizer parameters in checkpoint.')
    if optimizer is not None:
        return model, optimizer, start_epoch
    else:
        return model
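Two pieces of load_model are easy to isolate: stripping the prefix that DataParallel adds to parameter names, and recomputing the learning rate when resuming. The sketch below mirrors that logic on plain dictionaries and numbers; strip_module_prefix and resumed_lr are hypothetical names, and the sample keys, learning rate and step epochs are made up for illustration.

```python
def strip_module_prefix(state_dict):
    """Drop the leading 'module.' that DataParallel prepends to keys,
    leaving keys such as 'module_list...' untouched."""
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith('module') and not key.startswith('module_list'):
            cleaned[key[7:]] = value  # len('module.') == 7
        else:
            cleaned[key] = value
    return cleaned


def resumed_lr(base_lr, lr_step, start_epoch):
    """Multiply the learning rate by 0.1 for every step already passed."""
    lr = base_lr
    for step in lr_step:
        if start_epoch >= step:
            lr *= 0.1
    return lr


# Made-up keys and schedule, for illustration only.
weights = strip_module_prefix({'module.base.weight': 1, 'hm.0.bias': 2})
lr_at_100 = resumed_lr(1.0, [90, 120], start_epoch=100)
```

With steps at epochs 90 and 120, resuming at epoch 100 has passed only the first step, so the rate is scaled once.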
2. The model's forward()
def forward(self, x):
    # x = 1*3*512*512
    x = self.base(x)
    # x is a list of six elements = 1* [16*512*512, 32*256*256, 64*128*128, 128*64*64, 256*32*32, 512*16*16]
    x = self.dla_up(x)
    y = []
    for i in range(self.last_level - self.first_level):
        y.append(x[i].clone())
    # y = 1* [64*128*128, 128*64*64, 256*32*32]
    self.ida_up(y, 0, len(y))
    # y = 1* [64*128*128, 64*128*128, 64*128*128]
    z = {}
    for head in self.heads:
        z[head] = self.__getattr__(head)(y[-1])
    # z = {'hm' : 1*80*128*128,
    #      'reg' : 1*2*128*128,
    #      'wh' : 1*2*128*128}
    return [z]
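The shape comments in forward() can be reproduced with a little arithmetic. This sketch assumes a 512*512 input and the six backbone channel counts noted in the comments, with each level halving the resolution of the previous one. It also assumes first_level equals log2(down_ratio) = 2, which is my reading of the shapes rather than something shown in the code here. pyramid_shapes is a hypothetical helper, and only shapes are tracked, not feature values.

```python
import math


def pyramid_shapes(channels, input_size=512):
    """Shape (C, H, W) of each backbone level; level i has stride 2**i."""
    return [(c, input_size >> i, input_size >> i)
            for i, c in enumerate(channels)]


channels = [16, 32, 64, 128, 256, 512]
levels = pyramid_shapes(channels)

down_ratio = 4
first_level = int(math.log2(down_ratio))  # assumed to be 2
last_level = 5                            # value passed in get_pose_net
selected = levels[first_level:last_level]  # shapes of the maps copied into y

# After ida_up every entry has the shape of the first (finest) map.
fused = [selected[0]] * len(selected)
```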
2. ctdet_decode
ctdet_decode lives in models.decode. The detections it finally produces are the concatenation of bboxes, scores and clses:
detections = torch.cat([bboxes, scores, clses], dim=2)
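Here bboxes is in (top-left, bottom-right) corner form and is a 1*100*4 FloatTensor. scores is a 1*100*1 FloatTensor with values in [0, 1], sorted in descending order. clses is also a 1*100*1 tensor whose values are all integers representing the class. For the decoding process itself, see the earlier post: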
https://mp.csdn.net/postedit/91955759
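As a rough picture of how one row of dets comes about, here is a simplified pure-Python sketch. It assumes each box is centred on a heatmap peak shifted by reg, with wh giving the full width and height, and it picks the K highest scores over all classes and positions. It leaves out any peak filtering (NMS-style suppression) that the actual decoder performs before selecting the top K. decode_topk and the toy maps are hypothetical.

```python
def decode_topk(hm, wh, reg, k):
    """hm: [C][H][W] scores in [0, 1]; wh and reg: [2][H][W].

    Returns k rows [x1, y1, x2, y2, score, cls], sorted by score descending.
    """
    candidates = [
        (score, cls, y, x)
        for cls, plane in enumerate(hm)
        for y, row in enumerate(plane)
        for x, score in enumerate(row)
    ]
    candidates.sort(key=lambda c: c[0], reverse=True)

    dets = []
    for score, cls, y, x in candidates[:k]:
        cx = x + reg[0][y][x]            # centre x with sub-pixel offset
        cy = y + reg[1][y][x]            # centre y with sub-pixel offset
        w, h = wh[0][y][x], wh[1][y][x]  # full box width and height
        dets.append([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2,
                     score, float(cls)])
    return dets


# Toy maps: 2 classes on a 2*2 grid.
hm = [[[0.1, 0.9], [0.2, 0.3]],
      [[0.8, 0.0], [0.0, 0.0]]]
wh = [[[4.0, 2.0], [0.0, 0.0]],
      [[6.0, 4.0], [0.0, 0.0]]]
reg = [[[0.5, 0.5], [0.0, 0.0]],
       [[0.25, 0.0], [0.0, 0.0]]]
dets = decode_topk(hm, wh, reg, k=2)
```

With K=100 and one image, stacking the rows gives the 1*100*6 layout described above.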