今天大致看了一下CornerNet的代碼,對其中的關鍵代碼做一些整理。
由於CenterNet(CenterNet:Keypoint Triplets for Object Detection)是在CornerNet的基礎上修改來的,所以基本是一致的
cornernet的主要結構基本都定義在./models/py_utils文件夾下,主幹結構定義在./models/py_utils/kp.py這個文件內,部分結構也在kp_utils.py中實現,corner pooling在_cpools文件夾下使用C++語言實現。
接下來我主要總結了一下網絡的結構定義 class kp(nn.Module)、_decode()函數,以及corner pooling中的left pooling函數。
首先,我來介紹一下網絡的定義 class kp(nn.Module),其在kp.py文件中定義,該類主要實現了網絡的整體結構,以及train和test的前向的過程,可以說是本網絡的精髓所在。
class kp(nn.Module):
def __init__(
self, n, nstack, dims, modules, out_dim, pre=None, cnv_dim=256,
make_tl_layer=make_tl_layer, make_br_layer=make_br_layer,
make_cnv_layer=make_cnv_layer, make_heat_layer=make_kp_layer,
make_tag_layer=make_kp_layer, make_regr_layer=make_kp_layer,
make_up_layer=make_layer, make_low_layer=make_layer,
make_hg_layer=make_layer, make_hg_layer_revr=make_layer_revr,
make_pool_layer=make_pool_layer, make_unpool_layer=make_unpool_layer,
make_merge_layer=make_merge_layer, make_inter_layer=make_inter_layer,
kp_layer=residual
):
super(kp, self).__init__()
## nstack is the number of stacked hourglass modules (the authors
## default to 2). Stacking implements the paper's "intermediate
## supervision": each stack's output is supervised and also fed as
## input to the next stack (see _train/_test below).
self.nstack = nstack
## _decode post-processes the raw network outputs (heatmaps,
## embeddings, offsets) into matched corner pairs, i.e. the final
## detections. It is used by _test and documented separately below.
self._decode = _decode
curr_dim = dims[0]
## Stem ("pre"): a 7x7 stride-2 convolution followed by a stride-2
## residual block, unless a custom stem module is supplied via `pre`.
self.pre = nn.Sequential(
convolution(7, 3, 128, stride=2),
residual(3, 128, 256, stride=2)
) if pre is None else pre
### Backbone: one hourglass module per stack. The make_xx_layer
### factories are defined in kp_utils.py (not detailed here). Note
### that this and every head below is built inside
### `for _ in range(nstack)`, so each stack owns its own copy; the
### stacks are chained together by the intermediate layers further
### down.
self.kps = nn.ModuleList([
kp_module(
n, dims, modules, layer=kp_layer,
make_up_layer=make_up_layer,
make_low_layer=make_low_layer,
make_hg_layer=make_hg_layer,
make_hg_layer_revr=make_hg_layer_revr,
make_pool_layer=make_pool_layer,
make_unpool_layer=make_unpool_layer,
make_merge_layer=make_merge_layer
) for _ in range(nstack)
])
### A convolution layer applied to each hourglass output.
self.cnvs = nn.ModuleList([
make_cnv_layer(curr_dim, cnv_dim) for _ in range(nstack)
])
## Two branches per stack: one for top-left and one for bottom-right
## corners (these contain the corner-pooling layers).
self.tl_cnvs = nn.ModuleList([
make_tl_layer(cnv_dim) for _ in range(nstack)
])
self.br_cnvs = nn.ModuleList([
make_br_layer(cnv_dim) for _ in range(nstack)
])
## Keypoint heatmap heads: out_dim channels (one per class) for the
## top-left and bottom-right corner heatmaps.
self.tl_heats = nn.ModuleList([
make_heat_layer(cnv_dim, curr_dim, out_dim) for _ in range(nstack)
])
self.br_heats = nn.ModuleList([
make_heat_layer(cnv_dim, curr_dim, out_dim) for _ in range(nstack)
])
## Tag heads: a single-channel embedding per corner, used later to
## match top-left and bottom-right corners of the same object.
self.tl_tags = nn.ModuleList([
make_tag_layer(cnv_dim, curr_dim, 1) for _ in range(nstack)
])
self.br_tags = nn.ModuleList([
make_tag_layer(cnv_dim, curr_dim, 1) for _ in range(nstack)
])
# Initialize the heatmap output biases to -2.19; presumably the
# focal-loss prior init (-2.19 ~ log(0.1/0.9)) so heatmaps start
# with a low foreground probability — TODO confirm against paper.
for tl_heat, br_heat in zip(self.tl_heats, self.br_heats):
tl_heat[-1].bias.data.fill_(-2.19)
br_heat[-1].bias.data.fill_(-2.19)
## The next three ModuleLists form the intermediate-supervision
## bridge: a stack's output is projected and added back into the
## running feature map before entering the next stack (hence
## `nstack - 1` copies). See _train/_test for how they are combined.
self.inters = nn.ModuleList([
make_inter_layer(curr_dim) for _ in range(nstack - 1)
])
self.inters_ = nn.ModuleList([
nn.Sequential(
nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False),
nn.BatchNorm2d(curr_dim)
) for _ in range(nstack - 1)
])
self.cnvs_ = nn.ModuleList([
nn.Sequential(
nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False),
nn.BatchNorm2d(curr_dim)
) for _ in range(nstack - 1)
])
### Offset-regression heads: 2 channels (x, y) per corner, used to
### recover the quantization error introduced by downsampling.
self.tl_regrs = nn.ModuleList([
make_regr_layer(cnv_dim, curr_dim, 2) for _ in range(nstack)
])
self.br_regrs = nn.ModuleList([
make_regr_layer(cnv_dim, curr_dim, 2) for _ in range(nstack)
])
self.relu = nn.ReLU(inplace=True)
def _train(self, *xs):
    """Training forward pass.

    Args:
        xs: (image, tl_inds, br_inds) — input batch plus ground-truth
            corner indices used to gather tag/offset predictions at the
            annotated corner locations.

    Returns:
        Flat list with six tensors per hourglass stack:
        [tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr, ...]
        (every stack contributes, for intermediate supervision).
    """
    image, tl_inds, br_inds = xs[0], xs[1], xs[2]

    # Stem: 7x7 conv + residual. `inter` is the feature map threaded
    # from one stack to the next (intermediate supervision).
    inter = self.pre(image)

    outs = []
    stacks = zip(
        self.kps, self.cnvs,
        self.tl_cnvs, self.br_cnvs,
        self.tl_heats, self.br_heats,
        self.tl_tags, self.br_tags,
        self.tl_regrs, self.br_regrs
    )
    # One iteration per hourglass stack.
    for ind, (hg, cnv_layer,
              tl_cnv_layer, br_cnv_layer,
              tl_heat_layer, br_heat_layer,
              tl_tag_layer, br_tag_layer,
              tl_regr_layer, br_regr_layer) in enumerate(stacks):
        # Backbone + shared conv, then the two corner branches.
        kp = hg(inter)
        cnv = cnv_layer(kp)
        tl_cnv = tl_cnv_layer(cnv)
        br_cnv = br_cnv_layer(cnv)

        # Per-branch heads: heatmaps, embeddings (tags) and offsets.
        tl_heat = tl_heat_layer(tl_cnv)
        br_heat = br_heat_layer(br_cnv)
        tl_tag = tl_tag_layer(tl_cnv)
        br_tag = br_tag_layer(br_cnv)
        tl_regr = tl_regr_layer(tl_cnv)
        br_regr = br_regr_layer(br_cnv)

        # Gather tag/offset predictions at the ground-truth corner
        # positions so the loss is computed only there.
        tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
        br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)

        outs += [tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr]

        # Intermediate supervision bridge: project the incoming
        # features and this stack's output, sum, activate, and refine —
        # the result is the next stack's input.
        if ind < self.nstack - 1:
            inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
            inter = self.relu(inter)
            inter = self.inters[ind](inter)
    return outs
### test與train函數同理,唯一不同的是,train函數將nstack==0和nstack==1的輸出都放到了output中,而test只是將nstack==1的結果放到了output中,這裏就不詳細介紹了。
def _test(self, *xs, **kwargs):
    """Inference forward pass.

    Same pipeline as ``_train``, except that the prediction heads are
    evaluated only for the final stack, and the raw head outputs are
    decoded into detections (via ``self._decode``) before returning.
    """
    image = xs[0]
    inter = self.pre(image)

    outs = []
    stacks = zip(
        self.kps, self.cnvs,
        self.tl_cnvs, self.br_cnvs,
        self.tl_heats, self.br_heats,
        self.tl_tags, self.br_tags,
        self.tl_regrs, self.br_regrs
    )
    for ind, (hg, cnv_layer,
              tl_cnv_layer, br_cnv_layer,
              tl_heat_layer, br_heat_layer,
              tl_tag_layer, br_tag_layer,
              tl_regr_layer, br_regr_layer) in enumerate(stacks):
        kp = hg(inter)
        cnv = cnv_layer(kp)

        # At test time only the last stack's predictions are used.
        if ind == self.nstack - 1:
            tl_cnv = tl_cnv_layer(cnv)
            br_cnv = br_cnv_layer(cnv)
            tl_heat = tl_heat_layer(tl_cnv)
            br_heat = br_heat_layer(br_cnv)
            tl_tag = tl_tag_layer(tl_cnv)
            br_tag = br_tag_layer(br_cnv)
            tl_regr = tl_regr_layer(tl_cnv)
            br_regr = br_regr_layer(br_cnv)
            outs += [tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr]

        # Intermediate-supervision bridge, same as in _train.
        if ind < self.nstack - 1:
            inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
            inter = self.relu(inter)
            inter = self.inters[ind](inter)

    # Decode the final stack's six outputs into detections.
    return self._decode(*outs[-6:], **kwargs)
_decode這個函數的作用是處理模型的輸出結果,利用(heatmap, embedding, offset)的輸出,求出模型的檢測結果,下面介紹一下這個函數。
def _decode(
tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
K=100, kernel=1, ae_threshold=1, num_dets=1000
):
batch, cat, height, width = tl_heat.size()
## Map the top-left / bottom-right heatmap logits into (0, 1).
tl_heat = torch.sigmoid(tl_heat)
br_heat = torch.sigmoid(br_heat)
# Perform NMS on the heatmaps: max-pooling with `kernel`, keeping only
# local maxima (non-maxima are suppressed inside _nms).
tl_heat = _nms(tl_heat, kernel=kernel)
br_heat = _nms(br_heat, kernel=kernel)
## Take the top-K peaks of each corner heatmap, recording for each:
## score, flattened position index, class, y coordinate, x coordinate.
tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)
# Broadcast the coordinates to K x K so every top-left corner can be
# paired with every bottom-right corner, e.g. (rows repeat on the
# left, columns repeat on the right):
#[1,1,1 [ 1,2,3,
# 2,2,2 1,2,3,
# 3,3,3] 1,2,3 ]
# — this enumerates all K*K candidate corner pairs.
tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)
# Gather the predicted offsets at the top-K positions.
if tl_regr is not None and br_regr is not None:
tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
tl_regr = tl_regr.view(batch, K, 1, 2)
br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
br_regr = br_regr.view(batch, 1, K, 2)
# Refine the integer heatmap coordinates by adding the offsets.
tl_xs = tl_xs + tl_regr[..., 0]
tl_ys = tl_ys + tl_regr[..., 1]
br_xs = br_xs + br_regr[..., 0]
br_ys = br_ys + br_regr[..., 1]
# all possible boxes based on top k corners (ignoring class)
## Brute-force candidate boxes: every top-left corner is combined with
## every bottom-right corner, class ignored for now.
bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
### Gather the embedding (tag) value of each selected corner; pairs
### are only kept when their embeddings are close, i.e. the two
### corners belong to the same object.
tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
tl_tag = tl_tag.view(batch, K, 1)
br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
br_tag = br_tag.view(batch, 1, K)
### Absolute embedding distance for every top-left/bottom-right pair.
dists = torch.abs(tl_tag - br_tag)
#### Class scores of all top-left and bottom-right corners.
tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
##### Pair score = mean of the two corner scores.
scores = (tl_scores + br_scores) / 2
# The enumeration above produces many invalid pairs; the filters below
# mark them for rejection.
# reject boxes based on classes: corners of different classes.
tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
cls_inds = (tl_clses != br_clses)
# reject boxes based on distances: embedding gap above ae_threshold.
dist_inds = (dists > ae_threshold)
# reject boxes based on widths and heights: pairs where the top-left
# corner is not up-and-left of the bottom-right corner.
width_inds = (br_xs < tl_xs)
height_inds = (br_ys < tl_ys)
## Invalidate every rejected pair by forcing its score to -1.
scores[cls_inds] = -1
scores[dist_inds] = -1
scores[width_inds] = -1
scores[height_inds] = -1
scores = scores.view(batch, -1)
### Keep the top `num_dets` surviving pairs and their flat indices.
scores, inds = torch.topk(scores, num_dets)
scores = scores.unsqueeze(2)
## Use those indices to gather the matching boxes, classes and scores.
bboxes = bboxes.view(batch, -1, 4)
bboxes = _gather_feat(bboxes, inds)
clses = tl_clses.contiguous().view(batch, -1, 1)
clses = _gather_feat(clses, inds).float()
tl_scores = tl_scores.contiguous().view(batch, -1, 1)
tl_scores = _gather_feat(tl_scores, inds).float()
br_scores = br_scores.contiguous().view(batch, -1, 1)
br_scores = _gather_feat(br_scores, inds).float()
## Concatenate everything into the final detections tensor.
detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
return detections
Corner Pooling是用C++來完成的,這裏主要簡單介紹一下left pooling的做法,其他的同理,其實實現的就是下面這個過程:
// Left corner pooling: for each position, output the maximum of the
// input values from that position to the right end of its row.
std::vector<at::Tensor> pool_forward(
at::Tensor input
) {
// Initialize output: same shape (and dtype/device) as input.
at::Tensor output = at::zeros_like(input);
// Get width: size of the last dimension, which is scanned right-to-left.
int64_t width = input.size(3);
// Copy the last column. Left pooling walks each row from right to
// left, so the rightmost output column equals the rightmost input
// column; the next three lines perform that copy.
at::Tensor input_temp = input.select(3, width - 1);
at::Tensor output_temp = output.select(3, width - 1);
output_temp.copy_(input_temp);
// From the second-to-last column leftwards, write at each position the
// element-wise max of the input column and the already-computed output
// column immediately to its right (a running maximum).
at::Tensor max_temp;
for (int64_t ind = 1; ind < width; ++ind) {
input_temp = input.select(3, width - ind - 1);
output_temp = output.select(3, width - ind);
max_temp = output.select(3, width - ind - 1);
at::max_out(max_temp, input_temp, output_temp);
}
return {
output
};
}