本文是論文《Unsupervised Bidirectional Cross-Modality Adaptation via Deeply Synergistic Image and Feature Alignment for Medical Image Segmentation》的閱讀筆記。
文章提出了一個名爲 SIFA(Synergistic Image and Feature Alignment)的無監督域適應框架。SIFA 的代碼見 github 。SIFA 從圖像和特徵兩個角度引入了對齊的協同融合。
一、相關工作
域適應就是將從源域學習到的知識遷移到目標域中,在此之前 CycleGAN 在域適應方面取得了很好的效果。
SIFA 的一個關鍵特點是圖像變換和分割任務的共享編碼器。通過參數共享,本框架中的圖像對齊和特徵對齊能夠協同工作,減少端到端訓練過程中的域偏移(domain shift)。同時,另一個研究方向是特徵對齊,目的是在對抗性學習的情況下提取深度神經網絡的域不變特徵。
二、記號
s s s :源域
t t t :目標域
G t G_t G t :從源域到目標域的生成器,生成 x s → t x^{s\rightarrow t} x s → t
D t D_t D t :從源域到目標域的判別器,判別圖像是生成的還是真正來自目標域的
E E E :特徵編碼器
U U U :解碼器
C C C :像素級分類器
G s = E ∘ U G_s=E\circ U G s = E ∘ U :特徵編碼器+解碼器相當於一個源域生成器,生成 x t → s x^{t\rightarrow s} x t → s
E ∘ C E\circ C E ∘ C :特徵編碼器+像素級分類器相當於一個分割網絡,產生目標域圖像和生成的目標域圖像的分割標籤
D s D_s D s :判別生成的源域圖像來自生成的目標域圖像 x s → t x^{s\rightarrow t} x s → t 還是來自真正的目標域圖像 x t x_t x t 的判別器
D p D_p D p :對分割網絡生成的分割標籤進行判別的判別器
L a d v t ( G t , D t ) \mathcal{L}^t_{adv}(G_t,D_t) L a d v t ( G t , D t ) :目標域 GAN(G t , D t G_t,D_t G t , D t )的目標函數
L c y c ( G t , E , U ) \mathcal{L}_{cyc}(G_t,E,U) L c y c ( G t , E , U ) :源域-目標域-源域或目標域-源域-目標域的循環一致性損失
L s e g ( E , C ) \mathcal{L}_{seg}(E,C) L s e g ( E , C ) :分割網絡的混合損失
L a d v p ( E , C , D p ) \mathcal{L}_{a d v}^{p}(E, C, D_{p}) L a d v p ( E , C , D p ) :判別器 D p D_p D p 的對抗損失
L adv s ( E , D s ) \mathcal{L}_{\text {adv }}^{s}(E, D_{s}) L adv s ( E , D s ) :判別器 D s D_s D s 的對抗損失
L adv s ~ ( E , D s ) \mathcal{L}_{\text {adv }}^{\tilde{s}}(E, D_{s}) L adv s ~ ( E , D s ) :判別器 D s D_s D s 輔助任務的對抗損失
三、方法
1. 用於圖像對齊的外觀轉變
由於域偏移,跨域之間的圖片通常看起來不同,而圖像對齊的目的就是減少源域圖像和目標域圖像之間的這種差異。即給定一個有標籤的來自源域的數據集 { x i s , y i s } i = 1 N \{x_i^s,y_i^s\}_{i=1}^N { x i s , y i s } i = 1 N ,以及一個無標籤的來自目標域的數據集 { x i t } j = 1 M \{x_i^t\}_{j=1}^M { x i t } j = 1 M ,使得源域圖像 x i s x_i^s x i s 儘可能的看起來像 目標域圖像 x i t x_i^t x i t 。轉換後的圖像不僅要看起來像來自目標域,而且還應該保留源域的結構語義內容。
上圖是網絡的整體結構示意圖,可結合以下描述來加以理解。
(1)外觀轉變
使用一個生成器 G t G_t G t 將源域圖像轉換成與目標域相似的圖像,即 G t ( x s ) = x s → t G_t(x^s)=x^{s\rightarrow t} G t ( x s ) = x s → t ,並使用一個判別器 D t D_t D t 來判斷生成的圖像是真正來自目標域還是生成的。這個 GAN 的目標函數爲:
L adv t ( G t , D t ) = E x t ∼ X t [ log D t ( x t ) ] + E x s ∼ X s [ log ( 1 − D t ( G t ( x s ) ) ) ]
\begin{aligned}
\mathcal{L}_{\text {adv}}^{t}\left(G_{t}, D_{t}\right)=& \mathbb{E}_{x^{t} \sim X^{t}}\left[\log D_{t}\left(x^{t}\right)\right]+\\
& \mathbb{E}_{x^{s} \sim X^{s}}\left[\log \left(1-D_{t}\left(G_{t}\left(x^{s}\right)\right)\right)\right]
\end{aligned}
L adv t ( G t , D t ) = E x t ∼ X t [ log D t ( x t ) ] + E x s ∼ X s [ log ( 1 − D t ( G t ( x s ) ) ) ]
爲了讓轉換得到的圖像 x s → t x^{s\rightarrow t} x s → t 保留源域的特徵,通常使用一個反向的生成器來促進圖像的循環一致性。圖中的 E 是特徵編碼器,U 是解碼器,E 和 U 加起來就相當於一個生成器 G s G_s G s ,即 G s = E ∘ U G_s=E\circ U G s = E ∘ U ,它可以將轉換得到的目標域圖像 x s → t x^{s\rightarrow t} x s → t 再轉換回源域。並通過源域的判別器 D s D_s D s 進行判別,其對抗損失爲 L a d v s \mathcal{L}_{adv}^s L a d v s ,和目標域上的 GAN 的訓練方式一致。通過源域-目標域-源域(x s → t → s = U ( E ( G t ( x s ) ) ) x^{s \rightarrow t \rightarrow s}=U\left(E\left(G_{t}\left(x^{s}\right)\right)\right) x s → t → s = U ( E ( G t ( x s ) ) ) )或目標域-源域-目標域(x t → s → t = G t ( U ( E ( x t ) ) ) x^{t \rightarrow s \rightarrow t}=G_{t}\left(U\left(E\left(x^{t}\right)\right)\right) x t → s → t = G t ( U ( E ( x t ) ) ) )的轉換就得到了圖像的循環一致性損失,即:
L c y c ( G t , E , U ) = E x s ∼ X s ∥ U ( E ( G t ( x s ) ) ) − x s ∥ 1 + E x t ∼ X t ∥ G t ( U ( E ( x t ) ) ) − x t ∥ 1
\begin{aligned}
\mathcal{L}_{\mathrm{cyc}}\left(G_{t}, E, U\right)=& \mathbb{E}_{x^{s} \sim X^{s}}\left\|U\left(E\left(G_{t}\left(x^{s}\right)\right)\right)-x^{s}\right\|_{1}+\\
& \mathbb{E}_{x^{t} \sim X^{t}}\left\|G_{t}\left(U\left(E\left(x^{t}\right)\right)\right)-x^{t}\right\|_{1}
\end{aligned}
L c y c ( G t , E , U ) = E x s ∼ X s ∥ U ( E ( G t ( x s ) ) ) − x s ∥ 1 + E x t ∼ X t ∥ ∥ G t ( U ( E ( x t ) ) ) − x t ∥ ∥ 1
(4)目標域的分割網絡
圖中的 C 是一個像素級的分類器,E 和 C 加起來 E ∘ C E\circ C E ∘ C 就相當於一個目標域的分割網絡,它的輸入包括 x s → t , y s , x t x^{s\rightarrow t},y^s,x^t x s → t , y s , x t ,輸出是 x s → t , x t x^{s\rightarrow t},x^t x s → t , x t 的分割標籤,分割網絡通過最小化一個混合損失(分割損失)來優化:
L s e g ( E , C ) = H ( y s , C ( E ( x s → t ) ) + Dice ( y s , C ( E ( x s → t ) ) )
\mathcal{L}_{s e g}(E, C)=H\left(y^{s}, C\left(E\left(x^{s \rightarrow t}\right)\right)+\operatorname{Dice}\left(y^{s}, C\left(E\left(x^{s \rightarrow t}\right)\right)\right)\right.
L s e g ( E , C ) = H ( y s , C ( E ( x s → t ) ) + D i c e ( y s , C ( E ( x s → t ) ) )
其中第一項是交叉熵,第二項是 Dice 損失。
2. 特徵對齊的對抗學習
爲解決跨域的域偏移問題,文章提出了另外的判別器來從特徵對齊的角度來減少生成的目標圖像 x s → t x^{s\rightarrow t} x s → t 和真正的目標圖像 x t x^t x t 的 domain gap。爲了對齊以上兩種圖像的特徵,通常的方法是在特徵空間直接使用對抗學習,但是特徵空間一般是高維的,很難直接對齊。所以文章使用的方法是在兩個低維的空間內使用對抗學習,一個是語義預測空間,另一個是生成圖像空間。
(1)在語義預測空間的特徵對齊
使用判別器 D p D_p D p 來對分割網絡生成的分割標籤進行判別,如果兩者的特徵沒有對齊的話,就通過反向傳播對特徵提取器 E 進行優化,從而減小生成的目標域圖像 x s → t x^{s\rightarrow t} x s → t 和真正的目標域圖像 x t x^t x t 的特徵分佈之間的差異。該對抗損失爲:
L a d v p ( E , C , D p ) = E x s → t ∼ X s → t [ log D p ( C ( E ( x s → t ) ) ) ] + E x t ∼ X t [ log ( 1 − D p ( C ( E ( x t ) ) ) ) ]
\begin{aligned}
\mathcal{L}_{a d v}^{p}\left(E, C, D_{p}\right)=& \mathbb{E}_{x^{s \rightarrow t} \sim X^{s \rightarrow t}\left[\log D_{p}\left(C\left(E\left(x^{s \rightarrow t}\right)\right)\right)\right]+} \\
& \mathbb{E}_{x^{t} \sim X^{t}\left[\log \left(1-D_{p}\left(C\left(E\left(x^{t}\right)\right)\right)\right)\right]}
\end{aligned}
L a d v p ( E , C , D p ) = E x s → t ∼ X s → t [ log D p ( C ( E ( x s → t ) ) ) ] + E x t ∼ X t [ log ( 1 − D p ( C ( E ( x t ) ) ) ) ]
(2)語義預測空間的深度監督對抗學習
低級特徵可能和高級特徵的對齊情況並不一樣,所以使用額外的和編碼器低層的輸出相關的像素級分類器來產生額外的輔助預測,然後通過一個判別器來對這些額外預測進行判別。這增強了低級特徵的對齊,如此一來,L s e g \mathcal{L}_{seg} L s e g 和 L a d v \mathcal{L}_{adv} L a d v 的表達式就需要進行調整了,它們分別被拓展爲 L s e g i ( E , C i ) \mathcal{L}_{seg}^i(E,C_i) L s e g i ( E , C i ) 和 L a d v P i ( E , C i , D p i ) \mathcal{L}_{adv}^{P_i}(E,C_i,D_{p_i}) L a d v P i ( E , C i , D p i ) ,其中 i = 1 , 2 i={1,2} i = 1 , 2 ,C 1 , C 2 C_1,C_2 C 1 , C 2 表示連接到編碼器不同層的兩個分類器,D p 1 , D p 2 D_{p_1},D_{p_2} D p 1 , D p 2 表示對兩個分類器的輸出進行判別的判別器。
(4)生成圖像空間的特徵對齊
對於生成器 E ∘ U E\circ U E ∘ U ,爲判別器 D s D_s D s 增加一個輔助任務——判別生成的源域圖像來自生成的目標域圖像 x s → t x^{s\rightarrow t} x s → t 還是來自真正的目標域圖像 x t x^t x t 。該輔助任務的對抗損失爲:
L adv s ~ ( E , D s ) = E x s → t ∼ X s → t [ log D s ( U ( E ( x s → t ) ) ) ] + E x t ∼ X t [ log ( 1 − D s ( U ( E ( x t ) ) ) ) ]
\begin{aligned}
\mathcal{L}_{\text {adv }}^{\tilde{s}}\left(E, D_{s}\right)=& \mathbb{E}_{x^{s \rightarrow} t \sim X^{s \rightarrow t}}\left[\log D_{s}\left(U\left(E\left(x^{s \rightarrow t}\right)\right)\right)\right]+\\
& \mathbb{E}_{x^{t} \sim X^{t}}\left[\log \left(1-D_{s}\left(U\left(E\left(x^{t}\right)\right)\right)\right)\right]
\end{aligned}
L adv s ~ ( E , D s ) = E x s → t ∼ X s → t [ log D s ( U ( E ( x s → t ) ) ) ] + E x t ∼ X t [ log ( 1 − D s ( U ( E ( x t ) ) ) ) ]
3. 用於協同學習的共享編碼器
在協同學習框架的一個關鍵是在圖像和特徵對齊之間共享編碼器 E,編碼器 E 會通過損失 L a d v s \mathcal{L}_{adv}^s L a d v s 和 L c y c \mathcal{L}_{cyc} L c y c ,以及判別器 D p i , D s D_{p_i},D_s D p i , D s 的反向傳播來進行優化。
在訓練時各個模塊的訓練順序爲:G t → D t → E → C i → U → D s → D p i G_t\rightarrow D_t \rightarrow E \rightarrow C_i \rightarrow U \rightarrow D_s \rightarrow D_{p_i} G t → D t → E → C i → U → D s → D p i 。整個網絡的目標函數爲:
L = L a d v t ( G t , D t ) + λ a d v s L a d v s ( E , U , D s ) + λ g s L c s c ( G t , E , U ) + λ seg 1 L seg 1 ( E , C 1 ) + λ seg 2 L seg 2 ( E , C 2 ) + λ a d v p 1 L a d v p 1 ( E , C , D p 1 ) + λ adv p 2 L a d v p 2 ( E , C , D p 2 ) + λ a d v s ~ L a b s ~ ( E , D s )
\begin{aligned}
\mathcal{L}=& \mathcal{L}_{a d v}^{t}\left(G_{t}, D_{t}\right)+\lambda_{a d v}^{s} \mathcal{L}_{a d v}^{s}\left(E, U, D_{s}\right)+\\
& \lambda_{\mathrm{gs}} \mathcal{L}_{\mathrm{csc}}\left(G_{t}, E, U\right)+\lambda_{\operatorname{seg}}^{1} \mathcal{L}_{\operatorname{seg}}^{1}\left(E, C_{1}\right)+\\
& \lambda_{\operatorname{seg}}^{2} \mathcal{L}_{\operatorname{seg}}^{2}\left(E, C_{2}\right)+\lambda_{a d v}^{p_{1}} \mathcal{L}_{a d v}^{p_{1}}\left(E, C, D_{p_{1}}\right)+\\
& \lambda_{\text {adv}}^{p_{2}} \mathcal{L}_{a d v}^{p_{2}}\left(E, C, D_{p_{2}}\right)+\lambda_{a d v}^{\tilde{s}} \mathcal{L}_{a b}^{\tilde{s}}\left(E, D_{s}\right)
\end{aligned}
L = L a d v t ( G t , D t ) + λ a d v s L a d v s ( E , U , D s ) + λ g s L c s c ( G t , E , U ) + λ s e g 1 L s e g 1 ( E , C 1 ) + λ s e g 2 L s e g 2 ( E , C 2 ) + λ a d v p 1 L a d v p 1 ( E , C , D p 1 ) + λ adv p 2 L a d v p 2 ( E , C , D p 2 ) + λ a d v s ~ L a b s ~ ( E , D s )
其中 { λ a d v s , λ c y c , λ s e g 1 , λ s e g 2 , λ a d v p 1 , λ a d v p 2 , λ a d v s ~ } \left\{\lambda_{a d v}^{s}, \lambda_{c y c}, \lambda_{s e g}^{1}, \lambda_{s e g}^{2}, \lambda_{a d v}^{p_{1}}, \lambda_{a d v}^{p_{2}}, \lambda_{a d v}^{\tilde{s}}\right\} { λ a d v s , λ c y c , λ s e g 1 , λ s e g 2 , λ a d v p 1 , λ a d v p 2 , λ a d v s ~ } 是用於平衡各項的參數,在實驗時分別設爲 { 0.1 , 10 , 1.0 , 0.1 , 0.1 , 0.01 , 0.1 } \{0.1,10,1.0,0.1,0.1,0.01,0.1\} { 0 . 1 , 1 0 , 1 . 0 , 0 . 1 , 0 . 1 , 0 . 0 1 , 0 . 1 } 。
四、網絡設置和實施細節
生成器 G t G_t G t 採用的是和 CycleGAN 中一樣的設置,包括3個卷積層,9個殘差塊,2個反捲積層,然後再通過一個卷積層獲得生成的圖像。
解碼器 U 包括1個卷積層,4個殘差塊,3個反捲積層,然後再通過一個卷積層獲得輸出。
判別器 { D t , D s , D p } \{D_t,D_s,D_p\} { D t , D s , D p } 採用的是和 PatchGAN 一樣的設置,它的輸入是 70 × 70 70\times70 7 0 × 7 0 的patches,它包括5個卷積層,除了最後兩層卷積層步長爲1,其他的卷積核爲4,步長爲2。特徵圖的個數分別爲 { 64 , 128 , 256 , 512 , 1 } \{64,128,256,512,1\} { 6 4 , 1 2 8 , 2 5 6 , 5 1 2 , 1 } 。在前四層卷積層中每個卷積層後都跟着一個實例正則化和一個0.2的 Leaky ReLU。
編碼器 E 使用殘差連接和空洞卷積(dilation rate=2),來擴大分辨率的大小。用 { C k , R k , D k } \{Ck,Rk,Dk\} { C k , R k , D k } 分別表示通道數爲 k k k 的卷積層、殘差塊和空洞殘差塊;用 M 表示步長爲 2 的最大池化層;則編碼器的構成爲 { C 16 , R 16 , M , R 32 , M , 2 × R 64 , M , 2 × R 128 , 4 × R 256 , 2 × R 512 , 2 × D 512 , 2 × C 512 } \{C16,R16,M,R32,M,2\times R64,M,2\times R128,4\times R256,2\times R512,2\times D512,2\times C512\} { C 1 6 , R 1 6 , M , R 3 2 , M , 2 × R 6 4 , M , 2 × R 1 2 8 , 4 × R 2 5 6 , 2 × R 5 1 2 , 2 × D 5 1 2 , 2 × C 5 1 2 } 。每個卷積操作後都跟着一個批正則化和 ReLU 激活函數。
像素級分類器 C 1 C_1 C 1 連接到編碼器 E 的最後一層(2 × C 512 2\times C512 2 × C 5 1 2 )後面來得到輸出,C 2 C_2 C 2 最後連接到編碼器 E 的 2 × R 512 2\times R512 2 × R 5 1 2 塊的後面來得到輸出。C 1 , C 2 C_1,C_2 C 1 , C 2 都只包含一個 1 × 1 1\times1 1 × 1 的卷積層。
batch size 爲8,使用的是 Adam 優化器,學習率爲 2 × 1 0 − 4 2\times 10^{-4} 2 × 1 0 − 4 。