一文詳解擴散模型:DDPM

作者:京東零售 劉巖

擴散模型講解

前沿

人工智能生成內容(AI Generated Content,AIGC)近年來成爲了非常前沿的一個研究方向,生成模型目前有四個流派,分別是生成對抗網絡(Generative Adversarial Models,GAN),變分自編碼器(Variance Auto-Encoder,VAE),標準化流模型(Normalization Flow, NF)以及這裏要介紹的擴散模型(Diffusion Models,DM)。擴散模型是受到熱力學中的一個分支,它的思想來源是非平衡熱力學(Non-equilibrium thermodynamics)。擴散模型的算法理論基礎是通過變分推斷(Variational Inference)訓練參數化的馬爾可夫鏈(Markov Chain),它在許多任務上展現了超過GAN等其它生成模型的效果,例如最近非常火熱的OpenAI的DALL-E 2,Stability.ai的Stable Diffusion等。這些效果驚豔的模型擴散模型的理論基礎便是我們這裏要介紹的提出擴散模型的文章[1]和非常重要的DDPM[2],擴散模型的實現並不複雜,但其背後的數學原理卻非常豐富。在這裏我會介紹這些重要的數學原理,但省去了這些公式的推導計算,如果你對這些推導感興趣,可以學習參考文獻[4,5,11]的相關內容。我在這裏主要以一個相對簡單的角度來講解擴散模型,幫助你快速入門這個非常重要的生成算法。

1. 背景知識: 生成模型

目前生成模型主要有圖1所示的四類。其中GAN的原理是通過判別器和生成器的互相博弈來讓生成器生成足以以假亂真的圖像。VAE的原理是通過一個編碼器將輸入圖像編碼成特徵向量,它用來學習高斯分佈的均值和方差,而解碼器則可以將特徵向量轉化爲生成圖像,它側重於學習生成能力。流模型是從一個簡單的分佈開始,通過一系列可逆的轉換函數將分佈轉化成目標分佈。擴散模型先通過正向過程將噪聲逐漸加入到數據中,然後通過反向過程預測每一步加入的噪聲,通過將噪聲去掉的方式逐漸還原得到無噪聲的圖像,擴散模型本質上是一個馬爾可夫架構,只是其中訓練過程用到了深度學習的BP,但它更屬於數學層面的創新。這也就是爲什麼很多計算機的同學看擴散模型相關的論文會如此費力。

DDPM_1.png

圖1:生成模型的四種類型 [4]

擴散模型中最重要的思想根基是馬爾可夫鏈,它的一個關鍵性質是平穩性。即如果一個概率隨時間變化,那麼再馬爾可夫鏈的作用下,它會趨向於某種平穩分佈,時間越長,分佈越平穩。如圖2所示,當你向一滴水中滴入一滴顏料時,無論你滴在什麼位置,只要時間足夠長,最終顏料都會均勻的分佈在水溶液中。這也就是擴散模型的前向過程。

DDPM_2.png

圖2:顏料分子在水溶液中的擴散過程

如果我們能夠在擴散的過程顏料分子的位置、移動速度、方向等移動屬性。那麼也可以根據正向過程的保存的移動屬性從一杯被溶解了顏料的水中反推顏料的滴入位置。這邊是擴散模型的反向過程。記錄移動屬性的快照便是我們要訓練的模型。

2. 擴散模型

在這一部分我們將集中介紹擴散模型的數學原理以及推導的幾個重要性質,因爲推導過程涉及大量的數學知識但是對理解擴散模型本身思想並無太大幫助,所以這裏我會省去推導的過程而直接給出結論。但是我也會給出推導過程的出處,對其中的推導過程比較感興趣的請自行查看。

2.1 計算原理

擴散模型簡單的講就是通過神經網絡學習從純噪聲數據逐漸對數據進行去噪的過程,它包含兩個步驟,如圖3:

DDPM_3.png

圖3:DDPM的前向加噪和後向去噪過程

2.1.1 前向過程

2.1.2 後向過程

2.1.3 目標函數

那麼問題來了,我們究竟使用什麼樣的優化目標才能比較好的預測高斯噪聲的分佈呢?一個比較複雜的方式是使用變分自編碼器的最大化證據下界(Evidence Lower Bound, ELBO)的思想來推導,如式(6),推導詳細過程見論文[11]的式(47)到式(58),這裏主要用到了貝葉斯定理和琴生不等式。

\(\begin{aligned} \mathcal L & = - \log p(\boldsymbol x) \\ & = - \log \int \frac{p_\theta(\boldsymbol x_{0:T})q(\boldsymbol x_{1:T} | \boldsymbol x_0)}{q(\boldsymbol x_{1:T} | \boldsymbol x_0)} d \boldsymbol x_{1:T} \\ & \leq - \mathbb E_{q(\boldsymbol x_{1:T} | \boldsymbol x_0)} \left[ \frac{p_\theta(\boldsymbol x_{0:T})}{q(\boldsymbol x_{1:T} | \boldsymbol x_0)}\right] \\ & = - \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]}_{\text {重構項}} + \underbrace{D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right) \| p\left(\boldsymbol{x}_T\right)\right)}_{\text {先驗匹配項}} + \sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) \| p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)\right)\right]}_{\text {去噪匹配項}} \end{aligned} \tag6\)

式(6)的推導細節並不重要,我們需要重點關注的是它的最終等式的三個組成部分,下面我們分別介紹它們:

DDPM_4.png

圖4:擴散模型的去噪匹配項在每一步都要擬合噪音的真實後驗分佈和估計分佈

真實後驗分佈可以使用貝葉斯定理進行推導,最終結果如式(8),推導過程見論文[11]的式(71)到式(84)。

\(\begin{aligned} q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) & = \frac{q(\boldsymbol x_{t} | \boldsymbol x_{t-1}, \boldsymbol x_0) q(\boldsymbol x_{t-1} | \boldsymbol x_0)}{q(\boldsymbol x_{t} | \boldsymbol x_0)} \ & \propto \mathcal N \left( \boldsymbol x_{t-1}; \frac{\sqrt{\alpha_t} (1 - \bar{\alpha}{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}{t-1}}(1 - \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}t}, \frac{(1 - \alpha_t)(1 - \bar{\alpha}{t-1})}{1 - \bar{\alpha}t} \mathbf I \right) \ & = \mathcal N(\boldsymbol x{t-1}; \mu_q(\boldsymbol x_t, \boldsymbol x_0), \Sigma_q(t)) \end{aligned} \tag8 \)

\(p{\boldsymbol{\theta}}\left(\boldsymbol{x}{t-1} \mid \boldsymbol{x}t\right) = \mathcal N(\boldsymbol x{t-1}; \mu\theta(\boldsymbol x_t, t), \Sigma_q(t)) \tag9\)

\( \begin{aligned} & D_\text{KL}(\mathcal N(\boldsymbol x; \boldsymbol \mu_x, \boldsymbol \Sigma_x), \mathcal N(\boldsymbol y; \boldsymbol \mu_y, \boldsymbol \Sigma_y) \ = & \frac{1}{2}\left[ \log \frac{|\boldsymbol \Sigma_x|}{|\boldsymbol \Sigma_y|} - d + \text{tr}(\boldsymbol \Sigma_y^{-1} \boldsymbol \Sigma_x) + (\boldsymbol \mu_y - \boldsymbol \mu_x)^\intercal \boldsymbol \sigma_y^{-1}(\boldsymbol \mu_y - \boldsymbol \mu_x)\right]) \end{aligned} \tag{10} \)

\( \begin{aligned} \mathop{\arg\min}\theta D\text{KL}(q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) || p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)) = \mathop{\arg\min}\theta \frac{1}{2\sigma_q^2(t)} \left[|\boldsymbol \mu\theta (\boldsymbol x_t, \boldsymbol x_0) - \boldsymbol \mu_q(\boldsymbol x_t, t) |^2_2\right] \ \end{aligned} \tag{11} \)

2.1.4 模型訓練

雖然上面我們介紹了很多內容,並給出了大量公式,但得益於推導出的幾個重要性質,擴散模型的訓練並不複雜,它的訓練僞代碼見算法1。

DDPM_ag1.png

2.1.5 樣本生成

DDPM_ag2.png

2.2 算法實現

2.2.1模型結構

DDPM在預測施加的噪聲時,它的輸入是施加噪聲之後的圖像,預測內容是和輸入圖像相同尺寸的噪聲,所以它可以看做一個Img2Img的任務。DDPM選擇了U-Net[9]作爲噪聲預測的模型結構。U-Net是一個U形的網絡結構,它由編碼器,解碼器以及編碼器和解碼器之間的跨層連接(殘差連接)組成。其中編碼器將圖像降採樣成一個特徵,解碼器將這個特徵上採樣爲目標噪聲,跨層連接用於拼接編碼器和解碼器之間的特徵。

DDPM_5.png

圖5:U-Net的網絡結構

下面我們介紹DDPM的模型結構的重要組件。首先在U-Net的卷積部分,DDPM使用了寬殘差網絡(Wide Residual Network,WRN)[12]作爲核心結構,WRN是一個比標準殘差網絡層數更少,但是通道數更多的網絡結構。也有作者復現發現ConvNeXT作爲基礎結構會取得非常顯著的效果提升[13,14]。這裏我們可以根據訓練資源靈活的調整卷積結構以及具體的層數等超參。因爲我們在擴散過程的整個流程中都共享同一套參數,爲了區分不同的時間片,作者借鑑了Transformer [15]的位置編碼的思想,採用了正弦位置嵌入對時間$t$進行了編碼,這使得模型在預測噪聲時知道它預測的是批次中分別是哪個時間片添加的噪聲。在卷積層之間,DDPM添加了一個注意力層。這裏我們可以使用Transformer中提出的自注意力機制或是多頭自注意力機制。[13]則提出了一個線性注意力機制的模塊,它的特點是消耗的時間以及佔用的內存和序列長度是線性相關的,對比傳統注意力機制的平方相關要高效很多。在進行歸一化時,DDPM選擇了組歸一化(Group Normalization,GN)[16]。最後,對於U-Net中的降採樣和上採樣操作,DDPM分別選擇了步長爲2的卷積以及反捲積。

確定了這些組件,我們便可以搭建用於DDPM的U-Net的模型了。從第2.1節的介紹我們知道,模型的輸入爲形狀爲(batch_size, num_channels, height, width)的噪聲圖像和形狀爲(batch_size,1)的噪聲水平,返回的是形狀爲(batch_size, num_channels, height, width)的預測噪聲,我們搭建的用於噪聲預測的模型結構如下:

  1. 首先在噪聲圖像\( \boldsymbol x_0\)上應用卷積層,併爲噪聲水平$t$計算時間嵌入;
  2. 接下來是降採樣階段。採用的模型結構依次是兩個卷積(WRNS或是ConvNeXT)+GN+Attention+降採樣層;
  3. 在網絡的最中間,依次是卷積層+Attention+卷積層;
  4. 接下來是上採樣階段。它首先會使用Short-cut拼接來自降採樣中同樣尺寸的卷積,再之後是兩個卷積+GN+Attention+上採樣層。
  5. 最後是使用WRNS或是ConvNeXT作爲輸出層的卷積。

U-Net類的forword函數如下面代碼片段所示,完整的實現代碼參照[3]。

def forward(self, x, time):
    x = self.init_conv(x)
    t = self.time_mlp(time) if exists(self.time_mlp) else None
    h = []
    # downsample
    for block1, block2, attn, downsample in self.downs:
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        h.append(x)
        x = downsample(x)
    # bottleneck
    x = self.mid_block1(x, t)
    x = self.mid_attn(x)
    x = self.mid_block2(x, t)
    # upsample
    for block1, block2, attn, upsample in self.ups:
        x = torch.cat((x, h.pop()), dim=1)
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        x = upsample(x)
    return self.final_conv(x)


2.2.2 前向加噪

DDPM_6.png

圖6:一張圖依次經過0次,50次,100次,150次以及199次加噪後的效果圖

根據式(14)我們知道,擴散模型的損失函數計算的是兩張圖像的相似性,因此我們可以選擇使用迴歸算法的所有損失函數,以MSE爲例,前向過程的核心代碼如下面代碼片段。

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
   	# 1. 根據時刻t計算隨機噪聲分佈,並對圖像x_start進行加噪
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # 2. 根據噪聲圖像以及時刻t,預測添加的噪聲
    predicted_noise = denoise_model(x_noisy, t)
    # 3. 對比添加的噪聲和預測的噪聲的相似性
    loss = F.mse_loss(noise, predicted_noise)
    return loss


2.2.3 樣本生成

根據2.1.5節介紹的樣本生成流程,它的核心代碼片段所示,關於這段代碼的講解我通過註釋添加到了代碼片段中。

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    # 使用式(13)計算模型的均值
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    else:
      	# 獲取保存的方差
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # 算法2的第4行
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# 算法2的流程,但是我們保存了所有中間樣本
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device
    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs


最後我們看下在人臉圖像數據集下訓練的模型,一批隨機噪聲經過逐漸去噪變成人臉圖像的示例。

DDPM_7.gif

圖7:擴散模型由隨機噪聲通過去噪逐漸生成人臉圖像

3. 總結

這裏我們以DDPM爲例介紹了另一個派系的生成算法:擴散模型。擴散模型是一個基於馬爾可夫鏈的數學模型,它通過預測每個時間片添加的噪聲來進行模型的訓練。作爲近日來引發熱烈討論的ControlNet, Stable Diffusion等模型的底層算法,我們十分有必要對其有所瞭解。DDPM的實現並不複雜,這得益於大量數學界大佬通過大量的數學推導將整個擴散過程和反向去噪過程進行了精彩的化簡,這纔有了DDPM的大道至簡的實現。DDPM作爲一個擴散模型的基石算法,它有着很多早期算法的共同問題:

  1. 採樣速度慢:DDPM的去噪是從時刻$T$到時刻$1$的一個完整的馬爾可夫鏈的計算,尤其是DDPM還需要一個比較大的$T$才能保證比較好的效果,這就導致了DDPM的採樣過程註定是非常慢的;
  2. 生成效果差:DDPM的效果並不能說是非常好,尤其是對於高分辨率圖像的生成。這一方面是因爲它的計算速度限制了它擴展到更大的模型;另一方面它的設計還有一些問題,例如逐像素的計算損失並使用相同權值而忽略圖像中的主體並不是非常好的策略。
  3. 內容不可控:我們可以看出,DDPM生成的內容完全還是取決於它的訓練集。它並沒有引入一些先驗條件,因此並不能通過控制圖像中的細節來生成我們制定的內容。

我們現在已經知道,DDPM的這些問題已大幅得到改善,現在基於擴散模型生成的圖像已經達到甚至超過人類多數的畫師的效果,我也會在之後逐漸給出這些優化方案的講解。

Reference

[1] Sohl-Dickstein, Jascha, et al. "Deep unsupervised learning using nonequilibrium thermodynamics." International Conference on Machine Learning. PMLR, 2015.

[2] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.

[3] https://huggingface.co/blog/annotated-diffusion

[4] https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#simplification

[5] https://openai.com/blog/generative-models/

[6] Nichol, Alexander Quinn, and Prafulla Dhariwal. "Improved denoising diffusion probabilistic models." International Conference on Machine Learning. PMLR, 2021.

[7] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).

[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. "Reducing the dimensionality of data with neural networks." science 313.5786 (2006): 504-507.

[9] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.

[10] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.

[11] Luo, Calvin. "Understanding diffusion models: A unified perspective." arXiv preprint arXiv:2208.11970 (2022).

[12] Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).

[13] https://github.com/lucidrains/denoising-diffusion-pytorch

[14] Liu, Zhuang, et al. "A convnet for the 2020s." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

[15] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).

[16] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章