Part2: DDPM as Example of Variational Inference

很多次翻看DDPM,始終不太能理解論文中提到的\(\text{Variational Inference}\)到底是如何在這個工作中起到作用。五一假期在家,無意間又刷到徐亦達老師早些年錄製的理論視頻,沒想到其中也有介紹這部分的內容。老師的上課方式總是娓娓道來,把每一步都講解得很仔細。本文記錄一下個人對開頭問題的思考。

Background

如果需要簡略地介紹一下DDPM這個工作,可能會用以下幾句話簡單地描述:DDPMMarkov的形式對數據(圖片)“擴散過程”建模,使用神經網絡進行訓練擬合,學習數據的概率分佈。

所以對於生成任務來說,希望從給定數據中學習到的是數據的潛在信息。比如圖片生成,在給定一些圖片後,模型學習到的是“正常圖片長什麼樣子”,如:

  1. 一張包含手機正面的圖片會有【手機屏幕】;
  2. 一張包含貓咪的圖片會有人們觀察到的貓咪模樣;
  3. ...

對於圖片中每個像素點和附近的像素點,進行“合理”佈局,才能生成“符合人們認知的圖片”。

圖片生成能像常見的機器學習任務如分類任務、迴歸任務,能基於maximize likelihood的形式來訓練麼?

結論是很難,先回顧如何做maximum likelihood。給定一批數據,首先需要假定數據服從的分佈,接着寫出似然函數,之後直接通過解析解的形式或是梯度下降的形式,求出分佈。

問題就出在假定分佈這一步,沒有人知道圖片客觀上服從什麼分佈。那如果使用神經網絡直接擬合可以麼?這好像也不現實,拿一張512*512*3的圖片來說,網絡輸出層共有約75w的數值。

對於圖片生成還有另外一個問題,世界上的圖片太多了,目之所及稍做處理,皆爲圖片。即便使用神經網絡能擬合,最後生成的圖片很難存在多樣性。

那目前圖片生成模型都是怎麼做的,比如VAE或是本文即將要介紹的Diffusion Model,它們學習的都是數據分佈\(p(x)\),但直接求\(p(x)\)這麼麻煩,需要怎麼做?這其實也是\(\text{Variational Inference}\)的核心思想,“曲線救國”,通過引入其它分佈,將原本難以優化的問題轉變爲可優化問題。

ELOB

先把上述提到的所有背景先拋開,研究一下\(p(x)\),看看能得到什麼有意思的結論。

a. 基於條件概率分佈,引入新的隨機變量\(z\)\(p(x) = \frac{p(x, z)}{p(z\mid x)}\)

b. 對於兩邊同時取\(\ln\),等式依然成立,因此有:\(\ln{p(x)} = \ln{\frac{p(x, z)}{p(z \mid x)}}\)

c. 右邊分子分母同乘以\(q(z)\)\(\ln{p(x)} = \ln{\frac{p(x, z) * q(z)}{p(z \mid x) * q(z)}} = \ln{\left(\frac{p(x, z)}{q(z)} * \frac{q(z)}{p(z \mid x)}\right)} = \ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}}\)

d. 再次,對於上式左右兩邊求關於\(q(z)\)的期望,等式依然成立:

\[\begin{aligned} &\mathbb{E}_{z\sim q(z)}{[\ln{p(x)}]} = \mathbb{E}_{z\sim q(z)}{(\ln{\frac{p(x, z)}{q(z)}} + \ln{\frac{q(z)}{p(z \mid x)}})} \\ \iff & \int_z q(z)\ln{p(x)}dz = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \\ \iff & \ln{p(x)} = \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz + \int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz \end{aligned} \tag{1} \]

一系列變換後,\((1)\)式是最後的推導結果,等式右邊由兩個項組成。第二個項\(\int_z q(z)\ln{\frac{q(z)}{p(z \mid x)}}dz\),叫做KL散度,它被用來衡量兩個分佈之間的“距離”,性質是值不小於0

這樣一來,通過\((1)\)可以得到不等式\((2)\)

\[\begin{equation*} \ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz \end{equation*} \tag{2} \]

\((1)\)式右邊的第一項,同時也是\((2)\)式的右邊項,被學者們叫做\(\text{ELBO(Evidence Lower Bound)}\)

Objective Function

上述推導的\((2)\)式可以被視作“定理”一般的存在,即對於某個分佈的對數形式,總可以找到它的下界。

\((2)\)式可以用來做什麼?在Background中提到,圖片生成任務中的\(p(x)\)想要對它做maximum likelihood根本無法做起。目標依然是最大化\(p(x)\),但有了\((2)\)式,求解的目標可以轉移到最大化它的下界\(\text{ELBO}\)

這也是論文中提到的:

This paper presents progress in diffusion probabilistic models. A diffusion probabilistic model (which we will call a “diffusion model” for brevity) is a parameterized Markov chain trained using variational inference to produce samples matching the data after finite time.

接下來,回到論文中,看看是如何一步步推導出DDPM的優化目標。\((3)\)式直接摘錄於論文:

\[\begin{equation*} \ln{p(x)} \geq \int_z q(z)\ln{\frac{p(x, z)}{q(z)}}dz = \mathbb{E}_{z \sim q(z)}\left[\ln{\frac{p(x,z)}{q(z)}}\right] \end{equation*} \tag{2} \]

\[\begin{equation*} \mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right] \leq \mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right]=: L \end{equation*} \tag{3} \]

下面一項項地對\((3)\) 進行拆解,並且將它與\((2)\)比對,能幫助更好地理解:

  1. \((3)\)不等號左邊的\(\mathbb{E}\left[-\log p_\theta\left(\mathbf{x}_0\right)\right]\)進一步化簡就是\(-\log p_\theta\left(\mathbf{x}_0\right)\)。其中,\(p_\theta\left(\mathbf{x}_0\right)\)便是模型要學習的最終目標:圖像的分佈,\(\theta\)是模型的參數,\(\mathbf{x}_0\)是圖片;

  2. \((2)\)式的左右兩邊同時加上符號,\(\geq\)變爲\(\leq\)

  3. \((3)\)不等式右邊部分,\(\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right]\)

    1. 很明顯,\(q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)\)相當於\((2)\)中引入的額外分佈\(q(z)\)。對於\(z\),在生成模型中會給它一個稱呼:隱變量\((\text{latent})\)。實際上,在diffusion models裏,對\(\mathbf{x}_0\)加噪後的\(\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\)就可以看作隱變量,那不妨記作\(z := \{\mathbf{x}_1,\mathbf{x}_2,\ldots, \mathbf{x}_T\}\)

    2. \(p_\theta\left(\mathbf{x}_{0: T}\right) = p_\theta\left(\mathbf{x}_{0}, \mathbf{x}_{1}, \ldots, \mathbf{x}_{T}\right)\),是關於\(\mathbf{x}_0, z\)的聯合概率分佈,因爲選用馬爾代夫鏈建模,那麼依據馬爾可夫鏈的性質,論文定義:

\[\begin{equation*} \begin{aligned} q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)&:=\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) \\ p_\theta\left(\mathbf{x}_{0: T}\right)&:=p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right) \end{aligned} \end{equation*} \tag{4} \]

  1. \((4)\)帶入\((3)\)不等式右邊的第一項,得到\(L\)

\[\begin{equation*} \begin{aligned} &\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\ =&\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right) \prod_{t=1}^T p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{\prod_{t=1}^T q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\ =&\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] := L \end{aligned} \end{equation*} \]

到目前爲止,經過了很多輪的變換以及數學公式,先捋一遍,再往下。\(L\)是一個替代的優化目標,

\[\mathop{\arg\min}{(L)} \iff \mathop{\arg\min}{(-\ln{p}_{\theta}(\mathbf{x}_0))} \iff \mathop{\arg\max}{(\ln{p}_{\theta}(\mathbf{x}_0))} \]

接下來,論文中對\(L\)進行了重寫,以下步驟直接摘錄自論文\(\text{Appendix A}\)

\[\begin{equation*} \begin{aligned} L & =\mathbb{E}_q\left[-\log \frac{p_\theta\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \mid \mathbf{x}_0\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t \geq 1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}\right] \\ & =\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)}-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right] \\ &=\mathbb{E}_q\left[-\log p\left(\mathbf{x}_T\right)-\sum_{t>1} \log \left[\frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\right]-\log \frac{p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}{q\left(\mathbf{x}_1 \mid \mathbf{x}_0\right)}\right] \end{aligned} \end{equation*} \tag{5} \]

倒數兩步的變換髮生在第二項,具體依據爲:

\[\begin{aligned} q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) =& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1}\right)}{q\left(\mathbf{x}_{t-1}\right)} \\ =& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) *q(\mathbf{x}_{0})}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) * q(\mathbf{x}_{0})} \\ =& \frac{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) }{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)} \end{aligned} \quad \Rightarrow \quad \begin{aligned} &\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right)} \\ =& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_t, \mathbf{x}_{t-1} \mid \mathbf{x}_{0}\right) } \cdot {q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)} \\ =& \sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)} \cdot \frac{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_0\right)}{q\left(\mathbf{x}_t \mid \mathbf{x}_0\right)} \end{aligned} \]

接着對\((5)\)進行改寫得到最終形式\((6)\)

\[\begin{aligned} L &=\mathbb{E}_q\left[-\log \frac{p\left(\mathbf{x}_T\right)}{q\left(\mathbf{x}_T \mid \mathbf{x}_0\right)}-\sum_{t>1} \log \frac{p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)}{q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right)}-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)\right] \\ &=\mathbb{E}_q[\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_T \mid \mathbf{x}_0\right) \| p\left(\mathbf{x}_T\right)\right)}_{L_T}+\sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right)}_{L_{t-1}} \underbrace{-\log p_\theta\left(\mathbf{x}_0 \mid \mathbf{x}_1\right)}_{L_0}] \end{aligned} \tag{6} \]

Summary

太好了,對於\((6)\)來說,它最起碼是個可以優化的目標函數了,因爲論文中定義馬爾可夫鏈相鄰狀態的轉變是服從高斯分佈的。當然在論文中,\((6)\)還會進一步被改寫,得到更加精簡的\(\text{loss function}\)形式。
DDPM是應用\(\text{variational inference}\)進行優化求解的典型例子,很值得借鑑學習。

Reference

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章