很多次翻看DDPM,始終不太能理解論文中提到的\(\text{Variational Inference}\)到底是如何在這個工作中起到作用。五一假期在家,無意間又刷到徐亦達老師早些年錄製的理論視頻,沒想到其中也有介紹這部分的內容。老師的上課方式總是娓娓道來,把每一步都講解得很仔細。本文記錄一下個人對開頭問題的思考。
Background
如果需要簡略地介紹一下DDPM這個工作,可能會用以下幾句話簡單地描述:DDPM
以Markov
的形式對數據(圖片)“擴散過程”建模,使用神經網絡進行訓練擬合,學習數據的概率分佈。
所以對於生成任務來說,希望從給定數據中學習到的是數據的潛在信息。比如圖片生成,在給定一些圖片後,模型學習到的是“正常圖片長什麼樣子”,如:
- 一張包含手機正面的圖片會有【手機屏幕】;
- 一張包含貓咪的圖片會有人們觀察到的貓咪模樣;
- ...
對於圖片中每個像素點和附近的像素點,進行“合理”佈局,才能生成“符合人們認知的圖片”。
圖片生成能像常見的機器學習任務如分類任務、迴歸任務,能基於maximize likelihood
的形式來訓練麼?
結論是很難,先回顧如何做maximum likelihood
。給定一批數據,首先需要假定數據服從的分佈,接着寫出似然函數,之後直接通過解析解的形式或是梯度下降的形式,求出分佈。
問題就出在假定分佈這一步,沒有人知道圖片客觀上服從什麼分佈。那如果使用神經網絡直接擬合可以麼?這好像也不現實,拿一張512*512*3
的圖片來說,網絡輸出層共有約75w
的數值。
對於圖片生成還有另外一個問題,世界上的圖片太多了,目之所及稍做處理,皆爲圖片。即便使用神經網絡能擬合,最後生成的圖片很難存在多樣性。
那目前圖片生成模型都是怎麼做的,比如VAE
或是本文即將要介紹的Diffusion Model
,它們學習的都是數據分佈\(p(x)\),但直接求\(p(x)\)這麼麻煩,需要怎麼做?這其實也是\(\text{Variational Inference}\)的核心思想,“曲線救國”,通過引入其它分佈,將原本難以優化的問題轉變爲可優化問題。
ELOB
先把上述提到的所有背景先拋開,研究一下\(p(x)\),看看能得到什麼有意思的結論。
a. 基於條件概率分佈,引入新的隨機變量\(z\):\(p(x) = \frac{p(x, z)}{p(z\mid x)}\);
b. 對於兩邊同時取\(\ln\),等式依然成立,因此有:\(\ln{p(x)} = \ln{\frac{p(x, z)}{p(z \mid x)}}\);
c. 右邊分子分母同乘以\(q(z)\):\(\ln{p(x)} = \ln{\frac{p(x, z) * q(z)}{p(z \mid x) * q(z)}} = \ln{\left(\frac{p(x, z)}{q(z)} * \frac{q(z)}{p(z \mid x)}\right)} = \ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}}\)
d. 再次,對於上式左右兩邊求關於\(q(z)\)的期望,等式依然成立:
\[\begin{aligned}
&\mathbb{E}_{z\sim q(z)}{[\ln{p(x)}]} = \mathbb{E}_{z\sim q(z)}{(\ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}})} \\
\iff & \int_z q(z)\ln{p(x)}dz = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \\
\iff & \ln{p(x)} = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz
\end{aligned}
\tag{1}
\]
一系列變換後,\((1)\)式是最後的推導結果,等式右邊由兩個項組成。第二個項\(\int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz\),叫做KL散度,它被用來衡量兩個分佈之間的“距離”,性質是值不小於0。
這樣一來,通過\((1)\)可以得到不等式\((2)\):
\[\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz
\end{equation*}
\tag{2}
\]
\((1)\)式右邊的第一項,同時也是\((2)\)式的右邊項,被學者們叫做\(\text{ELBO(Evidence Lower Bound)}\)。
Objective Function
上述推導的\((2)\)式可以被視作“定理”一般的存在,即對於某個分佈的對數形式,總可以找到它的下界。
那\((2)\)式可以用來做什麼?在Background
中提到,圖片生成任務中的\(p(x)\)想要對它做maximum likelihood
根本無法做起。目標依然是最大化\(p(x)\),但有了\((2)\)式,求解的目標可以轉移到最大化它的下界\(\text{ELBO}\)。
這也是論文中提到的:
This paper presents progress in diffusion probabilistic models. A diffusion probabilistic model (which we will call a “diffusion model” for brevity) is a parameterized Markov chain trained using variational inference to produce samples matching the data after finite time.
接下來,回到論文中,看看是如何一步步推導出DDPM
的優化目標。\((3)\)式直接摘錄於論文:
\[\begin{equation*}
\ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz = \mathbb{E}_{z \sim q(z)}\left[\ln{\frac{p(x,z)}{q(z)}}\right]
\end{equation*}
\tag{2}
\]
\[\begin{equation*}
\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right] \leq \mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right]=: L
\end{equation*}
\tag{3}
\]
下面一項項地對\((3)\) 進行拆解,並且將它與\((2)\)比對,能幫助更好地理解:
-
\((3)\)不等號左邊的\(\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right]\)進一步化簡就是\(-\log p_\theta\left(\mathbf{x}_0\right)\)。其中,\(p_\theta\left(\mathbf{x}_0\right)\)便是模型要學習的最終目標:圖像的分佈,\(\theta\)是模型的參數,\(\mathbf{x}_0\)是圖片;
-
\((2)\)式的左右兩邊同時加上符號,\(\geq\)變爲\(\leq\);
-
看\((3)\)不等式右邊部分,\(\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]\)
-
很明顯,\(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\)相當於\((2)\)中引入的額外分佈\(q(z)\)。對於\(z\),在生成模型中會給它一個稱呼:隱變量\((\text{latent})\)。實際上,在diffusion models
裏,對\(\mathbf{x}_0\)加噪後的\(\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\)就可以看作隱變量,那不妨記作\(z := \{\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\}\);
-
\(p_\theta\left(\mathbf{x}_{0: T}\right) = p_\theta\left(\mathbf{x}_{0}, \mathbf{x}_{1}, \ldots, \mathbf{x}_{T}\right)\),是關於\(\mathbf{x}_0, z\)的聯合概率分佈,因爲選用馬爾代夫鏈建模,那麼依據馬爾可夫鏈的性質,論文定義:
\[\begin{equation*}
\begin{aligned}
q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)&:=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \\
p_\theta\left(\mathbf{x}_{0: T}\right)&:=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)
\end{aligned}
\end{equation*}
\tag{4}
\]
- 將\((4)\)帶入\((3)\)不等式右邊的第一項,得到\(L\):
\[\begin{equation*}
\begin{aligned}
&\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\
=&\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\
=&\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] := L
\end{aligned}
\end{equation*}
\]
到目前爲止,經過了很多輪的變換以及數學公式,先捋一遍,再往下。\(L\)是一個替代的優化目標,
\[\mathop{\arg\min}{(L)} \iff \mathop{\arg\min}{(-\ln{p}_{\theta}(\mathbf{x}_0))} \iff \mathop{\arg\max}{(\ln{p}_{\theta}(\mathbf{x}_0))}
\]
接下來,論文中對\(L\)進行了重寫,以下步驟直接摘錄自論文\(\text{Appendix A}\)
\[\begin{equation*}
\begin{aligned}
L & =\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right] \\
&=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \left[\frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\right]-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right]
\end{aligned}
\end{equation*}
\tag{5}
\]
倒數兩步的變換髮生在第二項,具體依據爲:
\[\begin{aligned}
q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1}\right)}{q\left(\mathbf{x}_{t-1}\right)} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) *q(\mathbf{x}_{0})}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) * q(\mathbf{x}_{0})} \\
=& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) }{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}
\end{aligned}
\quad \Rightarrow \quad
\begin{aligned}
&\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) } \cdot {q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)} \\
=& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}
\end{aligned}
\]
接着對\((5)\)進行改寫得到最終形式\((6)\):
\[\begin{aligned}
L &=\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\
&=\mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}]
\end{aligned}
\tag{6}
\]
Summary
太好了,對於\((6)\)來說,它最起碼是個可以優化的目標函數了,因爲論文中定義馬爾可夫鏈相鄰狀態的轉變是服從高斯分佈的。當然在論文中,\((6)\)還會進一步被改寫,得到更加精簡的\(\text{loss function}\)形式。
DDPM
是應用\(\text{variational inference}\)進行優化求解的典型例子,很值得借鑑學習。
Reference