Boundary Aware PoolNet(2):BASNet模型與代碼介紹

Boundary Aware PoolNet = PoolNet + BASNet,即使用BASNet中的Deep Supervision策略和Hybrid Loss改進PoolNet。

爲理解Boundary Aware PoolNet,我們並不需要學習整個BASNet,只需要瞭解其中的Deep Supervision策略和Hybrid Loss即可。

本文將簡單介紹BASNet的模型結構,重點介紹其Deep SupervisionHybrid Loss的理論和代碼實現。

相關文章彙總:

BASNet

傳送門

BASNet結構

img

如上圖所示,BASNet模型包括Predict Module和Refine Module。

  • Predict Module

    a U-Net-like densely supervised Encoder-Decoder network,作用是predict saliency map from input images。

    其實這個Encoder-Decoder結構和FPN(特徵金字塔網絡)沒什麼區別吧。

  • Refine Module

    refines the resulting saliency map of the prediction module by learning the residuals between the saliency map and the ground truth。

基於上述的2個Module,BASNet使用Deep Supervision(上圖中的Sup1-8)和Hybrid Loss進行模型訓練。

代碼

Predict Module的代碼在文件./model/BASNet.py中類BASNet中,Refine Module的代碼在文件./model/BASNet.py中類RefUnet中。

Deep Supervision

直白來講,Deep Supervision即使用神經網絡中多個層的Loss之和進行梯度下降。

如前文中BASNet結構圖所示,BASNet作者計算了Predict Module中的7層和Refine Module中的最後1層的Loss並進行求和,然後進行梯度下降,以此實現Deep Supervision。在計算邊路輸出時,需要進行上採樣和卷積使得邊路輸出的尺寸、通道數與輸入相同。

Deep Supervision的代碼在文件./model/BASNet.py的類BASNet的函數forward()中,可知類BASNetforward()時返回了8個邊路輸出,後繼計算這8層的Hybrid Loss並求和進行梯度下降。

Hybrid Loss

直白來講,Hybrid Loss即在計算損失時使用BCE Loss、SSIM Loss、IOU Loss這3個損失之和而非只使用BCE損失函數。

Hybrid Loss的代碼在文件./basnet_trin.py中的函數muti_bce_loss_fusion()中。該函數的輸入爲BASNet的8個邊路輸出和輸入對應的標註,該函數使用函數bce_ssim_loss()計算1個邊路輸出與標註的3種Loss之和。


Github(github.com):@chouxianyu

Github Pages(github.io):@臭鹹魚

知乎(zhihu.com):@臭鹹魚

博客園(cnblogs.com):@臭鹹魚

B站(bilibili.com):@絕版臭鹹魚

微信公衆號:@臭鹹魚

轉載請註明出處,歡迎討論和交流!


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章