Part3: Dive into DDPM

背景

整個系列有相對完整的公式推導,若正文中有涉及到的省略部分,皆額外整理在Part4,並會在正文中會指明具體位置。

Part2基於\(\text{Variational Inference}\),找到原目標函數\(-\ln{p_\theta(x_0)}\)的上界\(L\),定義如下:

\[\begin{aligned} L := & \mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\ =& \mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}]\end{aligned}\tag{1} \]

沿着論文的思路對\(L\)繼續精簡,得到最終在代碼層面實現的損失函數\(L_{simple}\)。同樣的,補充的推導見Part4;“擴散過程”的梗概介紹見Part1

簡化過程

不難看出\(L\)中的每一項皆爲KL散度。回顧forward processreverse process兩個階段的定義,馬爾可夫鏈的狀態轉移皆服從高斯分佈,如下所示:

\[\begin{aligned} q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) &:= \mathcal{N}\left(\mathbf{x}_t ; \sqrt{1-\beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}\right) \\ p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)&:=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right) \end{aligned}\tag{2}\]

同時,經過推導(見Part4推導二),易知:

\[\begin{aligned} q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) & =\mathcal{N}\left(\mathbf{x}_{t-1} ; \tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right), \tilde{\beta}_t \mathbf{I}\right) \\ \text { where } \tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right) & :=\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0+\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t, \ \ \tilde{\beta}_t:=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t\end{aligned}\tag{3}\]

KL散度比較皆發生在兩個Gaussian間。

\(L_{t}\)的簡化

可以看到,\((1)\)式中的\(L_T\)代表前向擴散過程,與待求解的參數項\(\theta\)無關,因此可被忽略:

\[\arg \min_{\theta} (L) \iff \arg \min_{\theta} \left(\mathbb{E}_q\left[ \sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}\right]\right) \]

注:在2015年提出diffusion框架的論文,前向擴散過程中的\(\beta_t\)是可以被學習的參數,故此處可視作DDPM第一處簡化

\(L_{t-1}\)的簡化

對於反向擴散過程的分佈\(p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\),共涉及到兩組參數\(\boldsymbol{\mu}_\theta\)\(\boldsymbol{\Sigma}_\theta\)DDPM第二處簡化是定義\(\boldsymbol{\Sigma}\)爲常數\(\sigma_t^2\),在計算中使用\(\beta_t\)\(\tilde{\beta}_t\)代替,故\(p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \sigma_t^2\mathbf{I}\right)\)

基於Part4兩個高斯的KL散度,對於\(L_{t-1}\),有:

\[L_{t-1}=\mathbb{E}_q\left[\frac{1}{2 \sigma_t^2}\left\|\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)-\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right]+C \tag{4} \]

其中\(C\)是個常數項。

仔細觀察\((4)\)不難發現,想要目標函數最小化,則\(\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)\)\(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\)間的“距離必須要近”。也就是說,深度網絡通過訓練,使得\(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\)趨近於\(\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)\)爲了使訓練更加簡單,嘗試對\((4)\)式改寫。

Foward Process中,\(\mathbf{x}_t\)可由\(\mathbf{x}_0\)\(\epsilon\)表示(見Part4推導一),不妨將\(\mathbf{x}_t\)記作\(\mathbf{x}_t({\mathbf{x}_0, \epsilon})\),故\(\mathbf{x}_0\)可以展開表示爲\(\mathbf{x}_t({\mathbf{x}_0, \epsilon})\)\(\epsilon\)的差:

\[\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \epsilon \ \Rightarrow \ \mathbf{x}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\epsilon)\tag{5} \]

又因爲\((3)\)式,故有:

\[\begin{aligned} \tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right) &= \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} \mathbf{x}_0+\frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t \\ &= \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1-\bar{\alpha}_t} * \frac{\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_{t}}\epsilon}{\sqrt{\bar{\alpha}_t}} + \frac{\sqrt{\alpha_t}\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_t} \mathbf{x}_t \\ &= \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon\right) \end{aligned}\tag{6-1} \]

前文提到,要優化\((4)\)式,則必然有:\(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right) \to \tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)\)。其中,\(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\)是深度網絡的輸出(預測)結果,\(\mathbf{x}_t\)\(t\)作爲模型的輸入參數。

\(\text{(6-1)}\)可知\(\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)\)能展開爲\(\mathbf{x}_t\)\(\epsilon\)的表達,\(\mathbf{x}_t\)已知,那不妨令原本要預測\(\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\)的深度網絡直接預測\(\boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\),變換前後依然等價。即

\[\begin{aligned} \boldsymbol{\mu}_{\theta^{\prime}}\left(\mathbf{x}_t, t\right) \iff \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right) \end{aligned} \tag{6-2} \]

此處以\(\theta^{\prime}\)\(\theta\)對變換前後的深度網絡參數進行區分,故\(\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right)\)需要無限趨近於\(\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)\)

\(\text{(6-1)}\)\(\text{(6-2)}\)代入\((4)\)式,有:

\[\begin{aligned} & L_{t-1}-C^\prime \\ =&\mathbb{E}_q\left[\frac{1}{2 \sigma_t^2}\left\|\tilde{\boldsymbol{\mu}}_t\left(\mathbf{x}_t, \mathbf{x}_0\right)-\boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right)\right\|^2\right] \\ \iff& \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\frac{1}{2 \sigma_t^2}\left\|\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t\left(\mathbf{x}_0, \boldsymbol{\epsilon}\right)-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}\right)-\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t\left(\mathbf{x}_0, \boldsymbol{\epsilon}\right)-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon_\theta}(\mathbf{x}_t, t)\right)\right\|^2\right] \\ =& \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\frac{\beta_t^2}{2 \sigma_t^2 \alpha_t\left(1-\bar{\alpha}_t\right)}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)\right\|^2\right] \end{aligned} \tag{7} \]

對比\((4)\)\((7)\),不難發現參數\(\theta\)作用的對象發生變化。在\((4)\)中,\(\theta\)的參數化對象爲高斯分佈的均值\(\boldsymbol{\mu}\);而在\((7)\)中,\(\theta\)的參數化對象轉移到\(\boldsymbol{\epsilon}\)實際上,不僅可以參數化\(\boldsymbol{\mu}\)\(\boldsymbol{\epsilon}\),也可以參數化\(\mathbf{x}_0\),只需要對\((5)\)中表示的主體進行變換即可。

並且,重新審視\(\text{(6-2)}\),該式與Part1中的採樣算法聯繫上了。上述目標函數的設定及推理,皆是爲了獲取反向過程的分佈\(p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right) :=\mathcal{N}\left(\mathbf{x}_{t-1} ; \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right), \boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right)\)

通過公式\(\text{(6-2)}\),按照反向過程相鄰狀態間的圖像轉換服從高斯分佈的定義,反向過程中知曉\(\mathbf{x}_t\)\(t\)後,通過深度網絡預測出\(\boldsymbol{\epsilon}\),再基於此求出\(\boldsymbol{\mu}_{t}\),結合自定義的\(\boldsymbol{\sigma}_{t}\),可採樣得到\(\mathbf{x}_{t-1}\),便實現反向過程的一次“降噪”。

\(L_{0}\)的簡化

這一項對應着信息由隱變量轉變回\(\mathbf{x_0}\),故而需要特殊考慮。

真實圖片中各個像素由0到255的數值組成,在處理時通常將所有像素值歸一化到區間[-1,1]。論文中將該項對應的優化目標定義爲:

\[\begin{aligned} p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right) & =\prod_{i=1}^D \int_{\delta_{-}\left(x_0^i\right)}^{\delta_{+}\left(x_0^i\right)} \mathcal{N}\left(x ; \mu_\theta^i\left(\mathbf{x}_1, 1\right), \sigma_1^2\right) d x \\ \delta_{+}(x) & =\left\{\begin{array}{ll}\infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1\end{array} \quad \delta_{-}(x)= \begin{cases}-\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1\end{cases} \right. \end{aligned} \tag{8} \]

其中,積分項是爲了與圖片真實像素的離散特性保持一致,\(D\)爲像素點的個數。

該項的優化目標是:對於輸入圖片\(x_0\)的所有像素位置,使得基於神經網絡產生的高斯分佈在該位置的採樣結果,與\(x_0\)對應位置的真實值相差不大。

直接文字闡述並不好理解,下方是對於單個位置的具體實例,截圖來自視頻
image

當前有一張真實的圖片\(x_0\),對應上圖內靠左邊的圖片,經過縮放後,在位置\(i\)的值爲\(x^i_0 = \frac{10}{255}\)

並且,中間圖片表示在\(x_1^i\)(此時還處於有噪聲狀態),經過神經網絡模型,預測出該位置的值是服從均值爲\(\frac{11}{255}\)的高斯分佈\(\mathcal{N^1}\)

在左下角畫出該\(\mathcal{N^1}\)的概率密度曲線,此時積分的上下界爲\((\frac{9}{255}, \frac{11}{255})\),從圖上可以直觀地看出積分對應的陰影面積相對來說比較大。故基於此採樣得到的\(\hat{x_0}^i\)與輸入圖片\(x_0^i\)接近的置信度很高。在訓練時反映出來的是,神經網絡在該位置的預測表現對\((8)\)式即Loss的貢獻程度較低;

但如果神經網絡預測出該位置的值服從服從均值爲\(\frac{105}{255}\)的高斯分佈\(\mathcal{N^2}\),此時概率密度曲線整體會往右平移,\((\frac{9}{255}, \frac{11}{255})\)區域屬於長尾位置,顯然積分結果比較小,從側面來說,基於此採樣得到的\(\hat{x_0}^{i_\prime}\)與輸入圖片\(x_0^i\)接近的置信度很低,在訓練時對Loss的貢獻程度高,在反向傳播時的梯度也大。

實際代碼實現中,該項被省略,這是第三處簡化

簡化的損失函數

回顧\((1)\)式,目前只剩下以\(L_{t-1}\)爲主體的求和部分,如下所示:

\[\begin{aligned} \arg \min_{\theta} (L) & \iff \arg \min_{\theta} \mathbb{E}_q\left[ \sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \right] \\ & \iff \arg \min_{\theta} \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[\frac{\beta_t^2}{2 \sigma_t^2 \alpha_t\left(1-\bar{\alpha}_t\right)}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)\right\|^2\right] \end{aligned} \tag{9} \]

對於\((9)\)式,DDPM第四處簡化在於省略了均方差損失項的權重,最終的損失函數\(L_{simple}\)爲:

\[\begin{aligned} L_{\text {simple }}(\theta):=\mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}, t\right)\right\|^2\right] \end{aligned} \]

總結

回顧本文,DDPM在損失函數上做了很多簡化,對於代碼側的實現非常友好。同時,論文作者也給出實驗對比,驗證簡化並不會使得結果變差,有些簡化(比如設置reverse過程中的\(\Sigma\)爲非參數項)甚至取得大幅度的提升效果。

Reference

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章