Pytorch|YOWO原理及代碼詳解(三)
本博客上接,
Pytorch|YOWO原理及代碼詳解(一),
Pytorch|YOWO原理及代碼詳解(二),閱前可看。
1. test分析
# Main driver (excerpt). NOTE(review): indentation was lost when this snippet
# was copied into the article, so the nesting below has to be inferred.
# With opt.evaluate a single validation pass is run; otherwise every epoch
# trains, validates, tracks the best F-score and saves a checkpoint.
if opt.evaluate:
logging('evaluating ...')
test(0)
else:
for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
# Train the model for 1 epoch
train(epoch)
# Validate the model
fscore = test(epoch)
# The current weights are "best" when their validation F-score beats the best so far.
is_best = fscore > best_fscore
if is_best:
print("New best fscore is achieved: ", fscore)
print("Previous fscore was: ", best_fscore)
best_fscore = fscore
# Save the model to backup directory
# is_best is forwarded, presumably so save_checkpoint can keep a separate
# copy of the best model — TODO confirm in save_checkpoint.
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'fscore': fscore
}
save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
logging('Weights are saved to backup directory: %s' % (backupdir))
上一節把train的整個流程分析完畢,本節主要分析test流程:fscore = test(epoch)
,進入(step into),查看完整的代碼如下:
def test(epoch):
    """Run one validation pass over ``testlist`` and return the detection F-score.

    For every test clip the model output is decoded with ``get_region_boxes``,
    filtered with ``nms`` and written to a per-frame detection file under
    ``ucf_detections/detections_<epoch>/`` (UCF101-24) or
    ``jhmdb_detections/detections_<epoch>/`` (any other dataset).  Each line of
    that file is ``<class_id (1-based)> <score> <x1> <y1> <x2> <y2>`` with pixel
    coordinates for a 320x240 frame.

    A ground-truth box is "detected" when its best-matching prediction has
    IoU > ``iou_thresh``; it is "correct" when that prediction also has the
    right class.  Precision / recall / F-score are cumulative and logged after
    every batch.

    Relies on module-level globals defined elsewhere in the script
    (``model``, ``region_loss``, ``basepath``, ``testlist``, ``dataset_use``,
    ``init_width``, ``init_height``, ``batch_size``, ``kwargs``, ``use_cuda``,
    ``nms_thresh``, ``iou_thresh``, ``eps``).

    Args:
        epoch: epoch number; used for logging and the output directory name.

    Returns:
        The cumulative F-score after the last batch (0.0 if there are no batches).
    """
    def truths_length(truths):
        """Return the number of real (non-padding) ground-truth rows in ``truths``."""
        # Labels are zero-padded; the first row whose x-centre is 0 ends the list.
        for k in range(len(truths)):
            if truths[k][1] == 0:
                return k
        # Every slot is occupied. The original loop fell off the end here and
        # returned None, which then crashed ``total + num_gts``.
        return len(truths)

    test_loader = torch.utils.data.DataLoader(
        dataset.listDataset(basepath, testlist, dataset_use=dataset_use, shape=(init_width, init_height),
                            shuffle=False,
                            transform=transforms.Compose([
                                transforms.ToTensor()
                            ]), train=False),
        batch_size=batch_size, shuffle=False, **kwargs)

    num_classes = region_loss.num_classes
    anchors = region_loss.anchors
    num_anchors = region_loss.num_anchors
    conf_thresh_valid = 0.005
    total = 0.0                   # ground-truth boxes seen so far
    proposals = 0.0               # predictions with objectness > 0.25
    correct = 0.0                 # GTs localized (IoU > iou_thresh) AND correctly classified
    fscore = 0.0
    correct_classification = 0.0  # localized GTs whose class is also right
    total_detected = 0.0          # GTs localized with IoU > iou_thresh

    nbatch = file_lines(testlist) // batch_size

    if dataset_use == 'ucf101-24':
        detection_root = 'ucf_detections'
    else:
        detection_root = 'jhmdb_detections'
    current_dir = os.path.join(detection_root, 'detections_' + str(epoch))
    # makedirs creates missing parents too. The original JHMDB branch used
    # os.mkdir(current_dir), which fails when 'jhmdb_detections' does not exist.
    os.makedirs(current_dir, exist_ok=True)

    logging('validation at epoch %d' % (epoch))
    model.eval()

    for batch_idx, (frame_idx, data, target) in enumerate(test_loader):
        if use_cuda:
            data = data.cuda()
        with torch.no_grad():
            output = model(data).data
            all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
            for i in range(output.size(0)):
                boxes = nms(all_boxes[i], nms_thresh)
                detection_path = os.path.join(current_dir, frame_idx[i])

                with open(detection_path, 'w+') as f_detect:
                    for box in boxes:
                        # Boxes are normalized (cx, cy, w, h); convert to pixel corners.
                        x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
                        y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
                        x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
                        y2 = round(float(box[1] + box[3] / 2.0) * 240.0)
                        det_conf = float(box[4])
                        # box[5:] holds (class_conf, class_id) pairs: the max class first,
                        # then any extra classes appended by get_region_boxes.
                        for j in range((len(box) - 5) // 2):
                            cls_conf = float(box[5 + 2 * j])
                            cls_id = int(box[6 + 2 * j])
                            prob = det_conf * cls_conf
                            # Write this pair's own class id. The original always wrote
                            # box[6], mislabelling every extra class as the max class.
                            f_detect.write(
                                str(cls_id + 1) + ' ' + str(prob) + ' ' + str(x1) + ' ' + str(y1) + ' ' + str(
                                    x2) + ' ' + str(y2) + '\n')

                truths = target[i].view(-1, 5)
                num_gts = truths_length(truths)
                total = total + num_gts

                for box in boxes:
                    if box[4] > 0.25:
                        proposals = proposals + 1

                for g in range(num_gts):
                    # Same layout as a predicted box: (cx, cy, w, h, det_conf, cls_conf, cls_id).
                    box_gt = [truths[g][1], truths[g][2], truths[g][3], truths[g][4], 1.0, 1.0, truths[g][0]]
                    best_iou = 0
                    best_j = -1
                    for j in range(len(boxes)):
                        iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False)
                        if iou > best_iou:
                            best_j = j
                            best_iou = iou

                    if best_iou > iou_thresh:
                        total_detected += 1
                        if int(boxes[best_j][6]) == box_gt[6]:
                            correct_classification += 1
                            correct = correct + 1

            precision = 1.0 * correct / (proposals + eps)
            recall = 1.0 * correct / (total + eps)
            fscore = 2.0 * precision * recall / (precision + recall + eps)
            logging(
                "[%d/%d] precision: %f, recall: %f, fscore: %f" % (batch_idx, nbatch, precision, recall, fscore))

    classification_accuracy = 1.0 * correct_classification / (total_detected + eps)
    locolization_recall = 1.0 * total_detected / (total + eps)
    print("Classification accuracy: %.3f" % classification_accuracy)
    print("Locolization recall: %.3f" % locolization_recall)
    return fscore
test的主要任務是返回一個fscore,主程序根據這個fscore判斷當前模型是否爲最優(is_best),並據此保存best模型參數。
2. 加載數據集
和train一樣的,test需要加載數據集,使用的是listDataset類,和train中使用的listDataset類一致,只不過在加載時模式不一樣,存在差異。進入(step into)listDataset類查看(__init__
是一致的,只不過__getitem__
不同):
def __getitem__(self, index):
    """Return one sample of the dataset.

    The ``index``-th line of the list file names the key frame. The clip is
    loaded with ``load_data_detection``, with augmentation when ``self.train``
    is true. Each frame is passed through ``self.transform`` and the frames are
    stacked into a single tensor.

    Args:
        index: position in ``self.lines``.

    Returns:
        ``(clip, label)`` in training mode, ``(frame_idx, clip, label)`` otherwise.
        ``clip`` is built with ``view((self.clip_duration, -1) + self.shape)``
        followed by ``permute(1, 0, 2, 3)``, i.e. channels first, then time.

    Raises:
        IndexError: if ``index >= len(self)``.
    """
    # The old ``assert index <= len(self)`` accepted ``index == len(self)``
    # (off by one) and is stripped entirely under ``python -O``.
    if index >= len(self):
        raise IndexError('index range error')
    imgpath = self.lines[index].rstrip()

    # NOTE(review): this overrides whatever shape was given to the constructor.
    self.shape = (224, 224)

    if self.train:  # For Training
        jitter = 0.2
        hue = 0.1
        saturation = 1.5
        exposure = 1.5
        clip, label = load_data_detection(self.base_path, imgpath, self.train, self.clip_duration, self.shape,
                                          self.dataset_use, jitter, hue, saturation, exposure)
    else:  # For Testing
        frame_idx, clip, label = load_data_detection(self.base_path, imgpath, False, self.clip_duration,
                                                     self.shape, self.dataset_use)
        # No augmentation at test time, so frames are only resized.
        clip = [img.resize(self.shape) for img in clip]

    if self.transform is not None:
        clip = [self.transform(img) for img in clip]

    # (self.clip_duration, -1) + self.shape, e.g. (8, -1, 224, 224) for an
    # 8-frame clip; the permute then puts channels before time.
    clip = torch.cat(clip, 0).view((self.clip_duration, -1) + self.shape).permute(1, 0, 2, 3)

    if self.target_transform is not None:
        label = self.target_transform(label)

    self.seen = self.seen + self.num_workers

    if self.train:
        return (clip, label)
    else:
        return (frame_idx, clip, label)
定位到代碼:frame_idx, clip, label = load_data_detection(self.base_path, imgpath, False, self.clip_duration, self.shape, self.dataset_use)
,並進入(step into)
# Load a clip of train_dur frames ending at the frame named by imgpath, plus its
# labels. NOTE(review): excerpt — the elided parts presumably define im_ind, d,
# max_num, im_split, labpath and the training-time augmentation.
def load_data_detection(base_path, imgpath, train, train_dur, shape, dataset_use='ucf101-24', jitter=0.2, hue=0.1, saturation=1.5, exposure=1.5):
......
# Walk backwards from im_ind in steps of d; indices outside [1, max_num]
# wrap around, so the clip loops over the video instead of running off its ends.
for i in reversed(range(train_dur)):
# make it as a loop
i_temp = im_ind - i * d
while i_temp < 1:
i_temp = max_num + i_temp
while i_temp > max_num:
i_temp = i_temp - max_num
# UCF101-24 frames are stored as .jpg, every other dataset as .png.
if dataset_use == 'ucf101-24':
path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1] ,'{:05d}.jpg'.format(i_temp))
else:
path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1] ,'{:05d}.png'.format(i_temp))
clip.append(Image.open(path_tmp).convert('RGB'))
if train: # Apply augmentation
......
else: # No augmentation
# Fixed-size label buffer: up to 50 boxes x 5 values, zero-padded.
label = torch.zeros(50*5)
# min_box_scale = 8 pixels expressed as a fraction of the frame width.
try:
tmp = torch.from_numpy(read_truths_args(labpath, 8.0/clip[0].width).astype('float32'))
# NOTE(review): any failure (e.g. a missing label file) silently becomes a
# single all-zero box row.
except Exception:
tmp = torch.zeros(1,5)
tmp = tmp.view(-1)
tsz = tmp.numel()
# Keep at most the first 50 boxes; otherwise copy all of them.
if tsz > 50*5:
label = tmp[0:50*5]
elif tsz > 0:
label[0:tsz] = tmp
if train:
return clip, label
else:
# Test mode also returns an id made of im_split[0..2] joined by '_';
# test() uses it as the detection file name.
return im_split[0] + '_' +im_split[1] + '_' + im_split[2], clip, label
和在train中一樣,clip是連續採樣的圖像序列,但test中不需要做數據增強。從labpath中讀取label(min_box_scale爲8.0/clip[0].width,即8個像素相對幀寬的比例):tmp = torch.from_numpy(read_truths_args(labpath, 8.0/clip[0].width).astype('float32'))
,查看完整代碼如下:
def read_truths_args(lab_path, min_box_scale, img_width=320, img_height=240):
    """Read ground-truth boxes from ``lab_path`` and convert them to normalized YOLO form.

    Each row returned by ``read_truths`` is interpreted as
    ``[class_id, x1, y1, x2, y2]``, with a 1-based class id and pixel corners.

    Args:
        lab_path: path of the label file passed to ``read_truths``.
        min_box_scale: boxes whose normalized width is below this are dropped.
        img_width: frame width in pixels used for normalization (default 320,
            the value previously hard-coded).
        img_height: frame height in pixels used for normalization (default 240).

    Returns:
        ``np.ndarray`` with one row ``[class_id (0-based), cx, cy, w, h]`` per kept
        box; all coordinates are normalized to the frame size.
    """
    truths = read_truths(lab_path)
    new_truths = []
    for row in truths:
        cls_id, x1, y1, x2, y2 = row[0], row[1], row[2], row[3], row[4]
        cx = (x1 + x2) / (2 * img_width)
        cy = (y1 + y2) / (2 * img_height)
        box_w = (x2 - x1) / img_width
        box_h = (y2 - y1) / img_height
        if box_w < min_box_scale:
            continue
        # Label files count classes from 1; the network counts from 0.
        new_truths.append([cls_id - 1, cx, cy, box_w, box_h])
    return np.array(new_truths)
cx和cy是target的中心座標,歸一化到0-1之間;同理imgw和imgh是target的寬和高,也歸一化到0-1之間。另外類別ID會減1(從1開始轉換爲從0開始)。如果target的寬太小(if truths[i][3] < min_box_scale:
),則將其省略,把轉換後符合要求的target放入new_truths中。這裏可以看到圖像的寬和高被硬編碼爲320×240,如果使用其他尺寸的數據集,這裏需要修改。
繼續回到load_data_detection
中,tmp是讀取到的標籤,tsz則是tmp中的元素數量。由於label = torch.zeros(50*5)
,label最多只能保存50個target(每個target 5個值),所以會對tsz進行判斷:如果tsz超出label可以保存的範圍,則只保留tmp的前50個標籤;反之則全部複製到label中(不足的部分保持爲0)。最後load_data_detection返回frame_idx, clip, label,回到listDataset類的__getitem__中:
和train一樣,clip被拼接成shape爲$[3, D, 224, 224]$(即$[C, D, H, W]$,D爲clip_duration)的張量:clip = torch.cat(clip, 0).view((self.clip_duration, -1) + self.shape).permute(1, 0, 2, 3)
,繼續返回到test(epoch)
中:
......
logging('validation at epoch %d' % (epoch))
model.eval()
# One iteration per test batch; frame_idx holds the per-clip ids returned by
# listDataset in test mode.
for batch_idx, (frame_idx, data, target) in enumerate(test_loader):
if use_cuda:
data = data.cuda()
with torch.no_grad():
output = model(data).data
# only_objectness=0, validation=1: score = objectness * class prob and
# extra (class_conf, class_id) pairs are appended to each box.
all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
for i in range(output.size(0)):
boxes = all_boxes[i]
boxes = nms(boxes, nms_thresh)
if dataset_use == 'ucf101-24':
detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
if not os.path.exists('ucf_detections'):
os.makedirs(current_dir)
if not os.path.exists(current_dir):
os.makedirs(current_dir)
else:
detection_path = os.path.join('jhmdb_detections', 'detections_' + str(epoch), frame_idx[i])
current_dir = os.path.join('jhmdb_detections', 'detections_' + str(epoch))
# NOTE(review): os.mkdir does not create parents, so this raises
# FileNotFoundError when 'jhmdb_detections' is missing;
# os.makedirs(current_dir, exist_ok=True) would cover both checks.
if not os.path.exists('jhmdb_detections'):
os.mkdir(current_dir)
if not os.path.exists(current_dir):
os.mkdir(current_dir)
......
3.模型輸出
加載完數據集之後,則獲取數據集並計算。得到模型的輸出:output = model(data).data
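,其shape爲$[B, (5+C)\times A, H, W]$(B爲batch大小,C爲類別數,A爲anchor數)。以UCF101-24爲例,當batch爲4、C=24、A=5、feature map爲7×7時,shape即爲$[4, 145, 7, 7]$。下面則是獲取所有的boxes:all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
,進入get_region_boxes,查看完整代碼: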
def get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1, validation=False):
    """Decode a YOLOv2-style region-layer output into per-image lists of boxes.

    Args:
        output: raw network output of shape
            ``(batch, (5 + num_classes) * num_anchors, h, w)``; a 3-D tensor is
            treated as a batch of one.
        conf_thresh: minimum score for a prediction to be kept.
        num_classes: number of object classes.
        anchors: flat sequence with ``len(anchors) // num_anchors`` values per
            anchor; the first two of each group are its width and height in
            grid-cell units.
        num_anchors: number of anchors per grid cell.
        only_objectness: if truthy the score is the objectness alone, otherwise
            objectness * highest class probability.
        validation: if truthy (and ``only_objectness`` is falsy), also append a
            ``(class_conf, class_id)`` pair for every other class whose
            objectness * class probability exceeds ``conf_thresh``.

    Returns:
        A list of ``batch`` lists.  Each box is
        ``[cx, cy, w, h, det_conf, cls_max_conf, cls_max_id, *extra_pairs]`` with
        coordinates normalized by the feature-map size.

    Works on any device: helper tensors follow ``output.device`` instead of
    being forced onto the current CUDA device.
    """
    anchor_step = len(anchors) // num_anchors
    if output.dim() == 3:
        output = output.unsqueeze(0)
    batch = output.size(0)
    assert (output.size(1) == (5 + num_classes) * num_anchors)
    h = output.size(2)
    w = output.size(3)
    device = output.device
    num_preds = batch * num_anchors * h * w

    # Rearrange to (5 + num_classes, batch * num_anchors * h * w): one column per
    # anchor prediction, rows = (t_x, t_y, t_w, t_h, objectness, class scores...).
    output = output.view(batch * num_anchors, 5 + num_classes, h * w).transpose(0, 1).contiguous().view(
        5 + num_classes, num_preds)

    # Grid-cell offsets (c_x, c_y) laid out in the same column order.
    grid_x = torch.linspace(0, w - 1, w).repeat(h, 1).repeat(batch * num_anchors, 1, 1).view(num_preds).to(device)
    grid_y = torch.linspace(0, h - 1, h).repeat(w, 1).t().repeat(batch * num_anchors, 1, 1).view(num_preds).to(device)
    xs = torch.sigmoid(output[0]) + grid_x
    ys = torch.sigmoid(output[1]) + grid_y

    # Anchor sizes (p_w, p_h), tiled over the batch and every grid cell.
    anchor_w = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([0]))
    anchor_h = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([1]))
    anchor_w = anchor_w.repeat(batch, 1).repeat(1, 1, h * w).view(num_preds).to(device)
    anchor_h = anchor_h.repeat(batch, 1).repeat(1, 1, h * w).view(num_preds).to(device)
    ws = torch.exp(output[2]) * anchor_w
    hs = torch.exp(output[3]) * anchor_h

    det_confs = torch.sigmoid(output[4])
    # Softmax over the class axis. dim=1 is what the deprecated implicit-dim
    # nn.Softmax() chose for 2-D input, so the results are unchanged.
    cls_confs = torch.softmax(output[5:5 + num_classes].transpose(0, 1), dim=1).detach()
    cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
    cls_max_confs = cls_max_confs.view(-1)
    cls_max_ids = cls_max_ids.view(-1)

    sz_hw = h * w
    sz_hwa = sz_hw * num_anchors
    det_confs = convert2cpu(det_confs)
    cls_max_confs = convert2cpu(cls_max_confs)
    cls_max_ids = convert2cpu_long(cls_max_ids)
    xs = convert2cpu(xs)
    ys = convert2cpu(ys)
    ws = convert2cpu(ws)
    hs = convert2cpu(hs)
    if validation:
        cls_confs = convert2cpu(cls_confs.view(-1, num_classes))

    all_boxes = []
    for b in range(batch):
        boxes = []
        for cy in range(h):
            for cx in range(w):
                for i in range(num_anchors):
                    # Flat column index of anchor i at cell (cy, cx) of image b.
                    ind = b * sz_hwa + i * sz_hw + cy * w + cx
                    det_conf = det_confs[ind]
                    if only_objectness:
                        conf = det_confs[ind]
                    else:
                        conf = det_confs[ind] * cls_max_confs[ind]
                    if conf > conf_thresh:
                        bcx = xs[ind]
                        bcy = ys[ind]
                        bw = ws[ind]
                        bh = hs[ind]
                        cls_max_conf = cls_max_confs[ind]
                        cls_max_id = cls_max_ids[ind]
                        box = [bcx / w, bcy / h, bw / w, bh / h, det_conf, cls_max_conf, cls_max_id]
                        if (not only_objectness) and validation:
                            for c in range(num_classes):
                                tmp_conf = cls_confs[ind][c]
                                if c != cls_max_id and det_confs[ind] * tmp_conf > conf_thresh:
                                    box.append(tmp_conf)
                                    box.append(c)
                        boxes.append(box)
        all_boxes.append(boxes)
    return all_boxes
將output的shape進行重塑:output = output.view(batch*num_anchors, 5+num_classes, h*w).transpose(0,1).contiguous().view(5+num_classes, batch*num_anchors*h*w)
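,重塑後的shape爲$[5+C, B\times A\times H\times W]$,即$[29, 980]$。這樣每一列就對應一個anchor的全部預測值($t_x, t_y, t_w, t_h$、置信度以及24個類別得分)。grid_x和grid_y是anchor所在grid cell在feature map上的座標(左上角)索引,和train中提到的一致,在Pytorch|YOWO原理及代碼詳解(二)中已講解過。sigmoid後的預測值加上座標索引(即偏移量),得到的xs和ys就是預測框中心在feature map上的絕對座標。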
下面是獲取5個anchor的w和h:
# Per-anchor widths / heights: anchors is viewed as (num_anchors, anchor_step)
# and column 0 (width) / column 1 (height) is selected.
anchor_w = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([0]))
anchor_h = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([1]))
並把每個anchor的w和h擴展到輸出的全部980(=4×5×7×7)個anchor預測上,使兩者一一對應:
# Tile the (num_anchors, 1) anchor sizes over the batch and every grid cell, then
# flatten them so they line up with the columns of the reshaped output.
# NOTE(review): .cuda() hard-codes a GPU; .to(output.device) would also run on CPU.
anchor_w = anchor_w.repeat(batch, 1).repeat(1, 1, h*w).view(batch*num_anchors*h*w).cuda()
anchor_h = anchor_h.repeat(batch, 1).repeat(1, 1, h*w).view(batch*num_anchors*h*w).cuda()
重塑後output的shape爲$[29, 980]$。其中output[2]對應980個anchor預測框的寬度偏移$t_w$,shape爲$[980]$;同理,output[3]對應預測框的高度偏移$t_h$,shape也爲$[980]$。那麼:
# YOLOv2 size decoding: b_w = p_w * exp(t_w), b_h = p_h * exp(t_h), where
# output[2] / output[3] hold t_w / t_h for every anchor prediction.
ws = torch.exp(output[2]) * anchor_w
hs = torch.exp(output[3]) * anchor_h
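則是計算每個anchor預測框的w和h,與yolov2的邊界框預測公式:
$$b_x=\sigma(t_x)+c_x,\quad b_y=\sigma(t_y)+c_y,\quad b_w=p_w e^{t_w},\quad b_h=p_h e^{t_h}$$
完全對應起來(其中$c_x, c_y$爲grid cell的偏移量,$p_w, p_h$爲anchor的寬和高)。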
det_confs是預測框的置信度(對output[4]取sigmoid)。cls_confs則是通過softmax得到的每個類別的得分,其shape爲$[980, 24]$,即980個anchor各自預測的24個類別得分。通過cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
計算每個anchor類別預測的最大概率值及其索引,並重塑成shape爲$[980]$的張量,以便和輸出的anchor一一對應。
接下來則是把一系列張量放到CPU上:
......
# Predictions per anchor plane (h*w) and per image (h*w*num_anchors); used to
# build flat indices in the box-filtering loop.
sz_hw = h*w
sz_hwa = sz_hw*num_anchors
# Copy the decoded tensors to CPU (helpers defined elsewhere) for the
# Python-level filtering loop that follows.
det_confs = convert2cpu(det_confs)
cls_max_confs = convert2cpu(cls_max_confs)
cls_max_ids = convert2cpu_long(cls_max_ids)
xs = convert2cpu(xs)
ys = convert2cpu(ys)
ws = convert2cpu(ws)
hs = convert2cpu(hs)
# The full per-class score table is only copied when extra class pairs may be appended.
if validation:
cls_confs = convert2cpu(cls_confs.view(-1, num_classes))
......
接下來則是按batch進行計算,對應batch中的一份數據集,先按feature map的x,y進行遍歷,即獲得每個grid cell的偏移量(左上角座標)。ind = b*sz_hwa + i*sz_hw + cy*w + cx
,ind則是獲取對應anchor的索引,即batch中的第b份數據,第i個anchor在cy,cx上的索引。batch中的一份數據有sz_hwa(245)個anchor,一張圖被分爲了sz_hw(49) 個grid cell,每個grid cell都有num_anchors(5)個anchor。
# With only_objectness the score is the objectness alone; otherwise it is
# objectness * highest class probability (test() passes only_objectness=0).
if only_objectness:
conf = det_confs[ind]
else:
conf = det_confs[ind] * cls_max_confs[ind]
訓練過程的test模式(only_objectness = 0時),其置信度是等於det_confs * cls_max_confs。如果置信度conf是大於閾值的,獲取對應的值,放入box中:box = [bcx/w, bcy/h, bw/w, bh/h, det_conf, cls_max_conf, cls_max_id]
。下面這段代碼則是把其餘類別中滿足“置信度×類別得分 > conf_thresh”的類別概率和對應的ID也追加到box中(每個類別佔2個元素)。因此遍歷完之後,box的長度最多爲53(7 + 2×23)。
# In validation mode (and not objectness-only), also append (score, class_id)
# for every other class whose objectness * class score exceeds conf_thresh.
if (not only_objectness) and validation:
for c in range(num_classes):
tmp_conf = cls_confs[ind][c]
if c != cls_max_id and det_confs[ind]*tmp_conf > conf_thresh:
box.append(tmp_conf)
box.append(c)
最後把一個batch的所有符合條件的boxes放入all_boxes中:all_boxes.append(boxes)
。接下來遍歷一個batch的所有boxes:boxes = all_boxes[i]
,然後使用非極大值抑制抑制得到最後的輸出:boxes = nms(boxes, nms_thresh)
,完整代碼如下:
def nms(boxes, nms_thresh):
    """Greedy non-maximum suppression.

    Boxes are visited in descending ``box[4]`` (confidence) order. Each box that
    is not yet suppressed and has a positive confidence is kept, and every later
    box whose IoU with it exceeds ``nms_thresh`` is suppressed.

    Args:
        boxes: list of boxes ``[cx, cy, w, h, det_conf, ...]`` in centre/size
            format, as produced by ``get_region_boxes``.
        nms_thresh: IoU above which a lower-confidence box is suppressed.

    Returns:
        The kept box objects (the same objects as in ``boxes``), highest
        confidence first.  ``boxes`` itself is returned unchanged when empty.
        Unlike the previous version, suppressed boxes are not modified in place
        (their confidence is no longer overwritten with 0).
    """
    if len(boxes) == 0:
        return boxes

    # torch.sort is ascending, so sort on (1 - conf) to visit the most
    # confident box first.
    det_confs = torch.zeros(len(boxes))
    for k, box in enumerate(boxes):
        det_confs[k] = 1 - box[4]
    _, sort_ids = torch.sort(det_confs)
    order = sort_ids.tolist()

    suppressed = [False] * len(boxes)
    out_boxes = []
    for pos, idx in enumerate(order):
        box_i = boxes[idx]
        if suppressed[idx] or not box_i[4] > 0:
            continue
        out_boxes.append(box_i)
        for idx_j in order[pos + 1:]:
            if bbox_iou(box_i, boxes[idx_j], x1y1x2y2=False) > nms_thresh:
                suppressed[idx_j] = True
    return out_boxes
det_confs存儲檢測出的box的置信度值。由於torch.sort
默認是升序排列,但需要的是置信度得分高的在前面,所以對det_confs進行det_confs[i] = 1-boxes[i][4]
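處理。sortIds則是按置信度從高到低排序後的box索引。非極大值抑制的核心是:保留置信度最高的框,並抑制(去除)與它重合度過高的其他框。如何判斷邊界框之間的重合程度呢?通過計算IOU值:bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh
,當大於閾值時,把box_j的置信度置0,即省略多餘的box_j。說實話,我覺得這裏的NMS有兩點值得注意:1. 調用時設置的是x1y1x2y2=False。需要說明的是,這裏box的格式是(cx, cy, w, h),即中心點座標加寬高,而不是(x1, y1, x2, y2),所以x1y1x2y2=False才是與數據格式匹配的設置,IOU的計算已經考慮了框的位置,而不僅僅是面積的重合;2. 被抑制的邊界框沒有和保留的框進行融合(例如Weighted NMS那樣按置信度加權融合座標),這是可以考慮的改進方向。
當然NMS存在多種算法,有興趣的可以自己再去了解一下。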
接下來則是創建檢測的路徑:
# Per-frame detection file: ucf_detections/detections_<epoch>/<frame id>.
# NOTE(review): both checks end in os.makedirs(current_dir), so the first one is
# redundant; os.makedirs(current_dir, exist_ok=True) alone would suffice.
if dataset_use == 'ucf101-24':
detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
if not os.path.exists('ucf_detections'):
os.makedirs(current_dir)
if not os.path.exists(current_dir):
os.makedirs(current_dir)
打開檢測文件:with open(detection_path, 'w+') as f_detect:
。獲取一個box檢測到的所有類別的得分:cls_conf = float(box[5 + 2 * j].item())
。計算概率值:prob = det_conf * cls_conf
,即置信度乘以類別得分,並把所有結果寫入檢測文件中,每行格式爲:類別ID(從1開始) 概率 x1 y1 x2 y2。
truths則是獲取對應圖片的標註:truths = target[i].view(-1, 5)
,其shape爲$[50, 5]$(每行依次爲類別ID、cx、cy、w、h,不足50個的部分補0)。通過truths_length
獲取真實的target數量。
# Every kept box with objectness > 0.25 counts as a proposal (the precision denominator).
for i in range(len(boxes)):
if boxes[i][4] > 0.25:
proposals = proposals + 1
把置信度大於閾值0.25的box計爲候選框proposals。把truths放入box_gt:box_gt = [truths[i][1], truths[i][2], truths[i][3], truths[i][4], 1.0, 1.0, truths[i][0]]
。遍歷檢測出的boxes,並計算iou值,得到best iou值和對應的box ID:
# Find the predicted box with the highest IoU against this ground truth
# (best_j stays -1 when no box overlaps it).
for j in range(len(boxes)):
iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False)
if iou > best_iou:
best_j = j
best_iou = iou
如果best IOU值大於閾值iou_thresh,則認爲正確定位到了一個目標(total_detected加1);如果類別也預測對了,則爲真陽性(correct加1):
# IoU above iou_thresh counts as a localization hit; a matching class id as well
# makes it a true positive. Both counters are incremented under the same condition.
if best_iou > iou_thresh:
total_detected += 1
if int(boxes[best_j][6]) == box_gt[6]:
correct_classification += 1
if best_iou > iou_thresh and int(boxes[best_j][6]) == box_gt[6]:
correct = correct + 1
剩下的就是各項指標計算:
......
# Cumulative metrics; eps guards against division by zero.
precision = 1.0 * correct / (proposals + eps)
recall = 1.0 * correct / (total + eps)
fscore = 2.0 * precision * recall / (precision + recall + eps)
logging(
"[%d/%d] precision: %f, recall: %f, fscore: %f" % (batch_idx, nbatch, precision, recall, fscore))
# Classification accuracy among localized GTs, and the fraction of GTs localized.
classification_accuracy = 1.0 * correct_classification / (total_detected + eps)
locolization_recall = 1.0 * total_detected / (total + eps)
print("Classification accuracy: %.3f" % classification_accuracy)
print("Locolization recall: %.3f" % locolization_recall)
return fscore
當計算完,返回主程序,根據fscore來保存best模型。
# Track the best validation F-score seen so far.
if is_best:
print("New best fscore is achieved: ", fscore)
print("Previous fscore was: ", best_fscore)
best_fscore = fscore
# Save the model to backup directory
# The checkpoint stores the epoch, model and optimizer states and this epoch's F-score.
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'fscore': fscore
}
save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
logging('Weights are saved to backup directory: %s' % (backupdir))
test流程基本分析完畢,剩下的請見:
Pytorch|YOWO原理及代碼詳解(四)(待更新)