Source code: https://github.com/amdegroot/ssd.pytorch
# layers/box_utils.py
def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes, then return the matched indices
    corresponding to both confidence and location preds.
    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
        variances: (tensor) Variances corresponding to each prior coord,
            Shape: [num_priors, 4].
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        idx: (int) current batch index
    Return:
        The matched indices corresponding to 1)location and 2)confidence preds.
    """
    # jaccard index
    overlaps = jaccard(truths, point_form(priors))
    # (Bipartite Matching)
    # [1,num_objects] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)  # ensure best prior
    # TODO refactor: index best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]          # Shape: [num_priors,4]
    conf = labels[best_truth_idx] + 1         # Shape: [num_priors]
    conf[best_truth_overlap < threshold] = 0  # label as background
    loc = encode(matches, priors, variances)
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
overlaps = jaccard(truths,point_form(priors))
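This computes the IoU (intersection over union, i.e. the jaccard overlap) between truths and priors. overlaps is a 2-D tensor of shape [num_objects, num_priors] whose entries are IoU values, with one row per truth and one column per prior. If truths and priors each contain 4 boxes, it is a 4x4 matrix, as the accompanying figure shows.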
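To make the entries of overlaps concrete, here is a minimal plain-Python sketch of the IoU computation. It is not the repository's vectorized jaccard; the helper names iou and overlaps_matrix and the box coordinates are made up for illustration. It assumes boxes are already in point form (xmin, ymin, xmax, ymax), which is why the source converts the priors with point_form first.

```python
def iou(box_a, box_b):
    """IoU of two boxes given in point form (xmin, ymin, xmax, ymax)."""
    # Width and height of the intersection rectangle, clamped at zero.
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)


def overlaps_matrix(truths, priors):
    """Rows are truths, columns are priors, like the overlaps tensor."""
    return [[iou(t, p) for p in priors] for t in truths]


# Hypothetical boxes: one truth and three priors.
truths = [(0.0, 0.0, 2.0, 2.0)]
priors = [(0.0, 0.0, 2.0, 2.0), (1.0, 1.0, 3.0, 3.0), (5.0, 5.0, 6.0, 6.0)]
overlaps = overlaps_matrix(truths, priors)
print(overlaps)  # identical box, partial overlap, no overlap
```

The first prior coincides with the truth (IoU 1), the second shares a 1x1 corner out of a union of 7 (IoU 1/7), and the third does not touch it (IoU 0).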
# [1,num_objects] best prior for each ground truth
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
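For each truth, this gets the value and index of its maximum IoU over the Priors set. For example, if the best prior for truth[1] has index 2, the corresponding value is overlaps[1,2].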
# [1,num_priors] best ground truth for each prior
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
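For each prior, this gets the value and index of its maximum IoU over the Truths set. For example, if the best truth for prior[2] has index 3, the corresponding value is overlaps[3,2] (rows index truths, columns index priors).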
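The two max calls can be mimicked on a small matrix. The sketch below uses plain Python lists in place of tensors so it runs without PyTorch. The IoU values are hypothetical, chosen so that truth[1] picks prior[2] while prior[2] prefers truth[3].

```python
# Rows are truths, columns are priors (hypothetical IoU values).
overlaps = [
    [0.7, 0.1, 0.0, 0.0],
    [0.1, 0.2, 0.3, 0.1],
    [0.0, 0.6, 0.2, 0.1],
    [0.0, 0.1, 0.6, 0.7],
]

# Like overlaps.max(1): best prior for each truth (max along each row).
best_prior_overlap = [max(row) for row in overlaps]
best_prior_idx = [row.index(max(row)) for row in overlaps]

# Like overlaps.max(0): best truth for each prior (max down each column).
columns = [list(col) for col in zip(*overlaps)]
best_truth_overlap = [max(col) for col in columns]
best_truth_idx = [col.index(max(col)) for col in columns]

print(best_prior_idx)  # [0, 2, 1, 3]
print(best_truth_idx)  # [0, 2, 3, 3]
```

Note that index 1 never appears in best_truth_idx: no prior chose truth[1] as its best match, even though truth[1] chose prior[2].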
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0)
best_prior_idx.squeeze_(1)
best_prior_overlap.squeeze_(1)
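These four statements squeeze the results down to one dimension. Passing keepdim=False to max (the default in current PyTorch releases) should give the same result without the extra calls.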
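A quick way to check the keepdim remark is to compare both variants on a small tensor. This is a sketch on made-up values, assuming a reasonably recent PyTorch installation.

```python
import torch

# Hypothetical IoU matrix: 2 truths, 3 priors.
overlaps = torch.tensor([[0.7, 0.1, 0.0],
                         [0.1, 0.2, 0.3]])

# The source's approach: keep the reduced dimension, then squeeze it in place.
val_keep, idx_keep = overlaps.max(1, keepdim=True)   # shapes [2, 1]
val_keep.squeeze_(1)
idx_keep.squeeze_(1)

# Alternative: let max drop the reduced dimension itself.
val_flat, idx_flat = overlaps.max(1, keepdim=False)  # shapes [2]

print(val_keep.shape, val_flat.shape)
print(torch.equal(val_keep, val_flat), torch.equal(idx_keep, idx_flat))
```

Both routes end with 1-D tensors holding the same values and indices, so the squeeze_ calls look like a leftover from the way the code was originally written.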
# TODO refactor: index best_prior_idx with long tensor
# ensure every gt matches with its prior of max overlap
for j in range(best_prior_idx.size(0)):
    best_truth_idx[best_prior_idx[j]] = j
As the comment says, these two lines ensure that every ground truth is matched with its best prior.
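First, be clear on one point: during training, a prior can be matched to only one truth. Now consider this example. truth[1] looks for its maximum IoU among the Priors and finds prior[2], so the two pair up (best_prior_idx[1]=2). But prior[2] then discovers that truth[3] is more attractive (an overlap as high as 0.6) and dumps truth[1] without hesitation (best_truth_idx[2]=3). As a result truth[1] is matched with no prior at all and stays alone forever.
Now look back at the for loop: its job is precisely to talk prior[2] into coming back to truth[1] (best_truth_idx[2]=1).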
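The effect of the loop can be traced by hand on a toy case. In this sketch plain Python lists stand in for the tensors, and the index values are hypothetical, chosen to match the story above (best_prior_idx[1]=2 and best_truth_idx[2]=3).

```python
best_prior_idx = [0, 2, 1, 3]  # best prior for each truth
best_truth_idx = [0, 2, 3, 3]  # best truth for each prior; truth 1 is missing

# Same loop as in match(): each truth claims its best prior.
for j in range(len(best_prior_idx)):
    best_truth_idx[best_prior_idx[j]] = j

print(best_truth_idx)  # [0, 2, 1, 3]
```

After the loop prior[2] points at truth[1] again, and every truth index appears at least once in best_truth_idx.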
best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior
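Although prior[2] has come back, their IoU of 0.3 is still too small: it falls below the threshold (0.5), so the prior would be treated as a negative sample. The fix is simply to raise the IoU so that it exceeds the threshold. index_fill_ writes 2 at those positions, and since a real IoU never exceeds 1, these priors always pass the threshold check.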
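The sketch below shows what the override changes when class labels are assigned. It uses plain Python lists in place of tensors. The class ids, overlaps, and the helper assign_conf are hypothetical and only mirror the conf lines of match().

```python
threshold = 0.5
labels = [11, 14]                     # hypothetical class ids, one per truth
best_prior_idx = [0, 1]               # best prior for each truth
best_truth_idx = [0, 1, 1]            # best truth for each prior
best_truth_overlap = [0.7, 0.3, 0.2]  # IoU of each prior with its best truth


def assign_conf(truth_overlap):
    # Class id + 1 for positives, 0 (background) when IoU is below threshold.
    return [labels[t] + 1 if iou >= threshold else 0
            for t, iou in zip(best_truth_idx, truth_overlap)]


without_fill = assign_conf(best_truth_overlap)

# Equivalent of best_truth_overlap.index_fill_(0, best_prior_idx, 2).
filled = list(best_truth_overlap)
for p in best_prior_idx:
    filled[p] = 2.0  # any value above 1 works, since a real IoU is at most 1

with_fill = assign_conf(filled)
print(without_fill)  # [12, 0, 0]
print(with_fill)     # [12, 15, 0]
```

Without the fill, truth 1 gets no positive prior because its best IoU is only 0.3. With the fill, prior 1 keeps the label of truth 1, while prior 2 stays background.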