標題黨一下,順便蹭一下 OpenAI Sora大模型的熱點,主要也是回顧一下擴散模型的原理。
1. 簡單理解擴散模型
簡單理解,擴散模型如下圖所示可以分成兩部分,一個是 forward,另一個是 reverse 過程:
- forward:這是加噪聲的過程,表示爲\(q(X_{0:T})\),即在原圖(假設是\(t_0\)時刻的數據,即\(X_0\))的基礎上分時刻(一般是 T 個時刻)逐步加上噪聲數據,最終得到\(t_T\)時刻的數據\(X_T\)。具體來說我們每次加一點噪聲,可能加了 200 次噪聲後得到服從正態分佈的隱變量,即\(X_t=X_0+ z_0+ z_1+...+ z_{t-1}\)每個時刻加的噪聲會作爲標籤用來在逆向過程的時候訓練模型。
- reverse:這很好理解,其實就是去噪過程,是\(q(X_{0:T})\)的逆過程,表示爲\(P_\theta(X_{0:T})\),即逐步對數據\(X_T\)逆向地去噪,儘可能還原得到原圖像。逆向過程其實就是需要訓練一個模型來預測每個時刻的噪聲 \(z_T\),從而得到上一時刻的圖像,通過迭代多次得到原始圖像,即\(X_0=X_t-z_t-z_{t-2}-...-z_1\)。模型訓練會迭代多次,每次的輸入是當前時刻數據\(X_t\),輸出是噪聲\(z_t\),對應標籤數據是\(\overline z_{t-1}\),損失函數是\(mse(z_t,\overline z_{t-1})\)
怎麼理解這兩個過程呢?一種簡單的理解方法是我們可以假設世界上所有圖像都是可以通過加密(就是 forward 過程)表示成隱變量,這些隱變量人眼看上去就是一堆噪聲點。我們可以通過神經網絡模型逐漸把這些噪聲去掉,從而得到對應的原圖(即 reverse 過程)。
2. 前向過程的數學表示
前向過程簡單理解就是不斷加噪聲,加噪聲的特點是越加越多:
- 前期加的噪聲要少一點,這樣是爲了避免加太多噪聲會導致模型不太好學習;
- 而當噪聲量加的足夠多後應該增加噪聲的量,因爲如果還是每次只加一點點,其實差別不大,而且這會導致前向過程太長,那麼對應逆向過程也長,最終會增加計算量。所以噪聲的量會有超參數\(\beta_t\)控制。t 越大,\(\beta_t\)的值也就越大。
那我們可以很自然地知道,t 時刻的圖像應該跟 t-1時刻的圖像和噪聲相關,所以有
其中\(\alpha_t=1-\beta_t\), \(z_1\)是服從 (0,1) 正太分佈的隨機變量。常見的參數設置是\(\beta_t\)從 0.0001 逐漸增加到0.002,所以\(\alpha_t\)對應越來越小,也就是說噪聲的佔比逐漸增大。
我們同樣有\(X_{t-1}=\sqrt{\alpha_{t-1}}X_{t-2}+\sqrt{1-\alpha_{t-1}}z_2\),此時我們有
因爲\(z_1,z_2\)都服從正太分佈,且\(\mathcal{N}(0,\sigma_{1}^{2})+\mathcal{N}(0,\sigma_{2}^{2})\sim\mathcal{N}(0,(\sigma_{1}^{2}+\sigma_{2}^{2}))\),所以公式(2)的括號內的兩項之和得到一個新的服從均值爲 0, 方差是\(\sqrt{(a_{t}(1-\alpha_{t-1})}^2+\sqrt{1-\alpha_{t}}^2=1-\alpha_t\alpha_{t-1}\)的變量\(\tilde z_2\sim\mathcal{N}(0,1-\alpha_t\alpha_{t-1})\)。
我們不斷遞歸能夠得到\(X_t\)和\(X_0\)的關係如下:
其中\(\overline{\alpha}_t=\alpha_t\alpha_{t-1}...\alpha_{1}\), \(\overline{z}_t\)是均值爲 0,方差\(\sigma=1-\overline{\alpha}_t\)的高斯變量, \(z_t\)服從(0,1)正態分佈。可以看到給定0 時刻的圖像數據\(X_0\),我們可以求得任意t時刻的\(\overline{\alpha}_t\)和與之有關的\(\overline z_t\),進而得到對應的\(X_t\)數據,至此前向過程就結束了。
3. 逆向過程的數學表示
3.1 貝葉斯公式求解
擴散模型在應用的時候主要就是 reverse 過程,即給定一組隨機噪聲,通過逐步的還原得到想要的圖像,可以表示爲\(q(X_0|X_t)\)。但是很顯然,我們無法直接從 T 時刻還原得到 0 時刻的數據,所以退而求其次,先求\(q(X_{t-1}|X_t)\)。但是這個也沒那麼容易求得,但是由貝葉斯公式我們可以知道
我們這裏考慮擴散模型訓練過程,我們默認是知道\(X_o\)的,所以有
解釋一下上面的公式:因爲我們可以人爲設置噪聲分佈,所以正向過程中每個時刻的數據也是知道的。例如,假設噪聲\(z\)是服從高斯分佈的,那麼\(X_1=X_0+z\),所以\(q(X_1,X_0)\)是可以知道的,同樣\(q(X_{t-1},X_0),q(X_t,X_0)\)也都是已知的,更一般地,\(q(X_t|X_{t-1},X_0)\)也是已知的。所以上面公式的右邊三項都是已知的,要計算出左邊的結果,就只需要分別求出右邊三項的數學表達式了。
上面三個公式是推導後的結果,省略了億些步驟,我們待會解釋怎麼來的,這裏先簡單解釋一下含義,我們看第一行,\(z\)就是服從正態分佈(均值爲 0,方差爲 1)的變量,爲方便理解其它的可以看成常數,我們知道 \(a+\sqrt{b}z\)會得到均值爲 a,方差爲 b 的服從高斯分佈的變量,那麼第一行最右邊的高斯分佈應該就好理解了。其餘兩行不做贅述,同理。
3.2 高斯分佈概率密度分佈計算
下面公式中左邊的概率分佈其實就是右邊三項概率分佈的計算結果。
我們假設了噪聲數據服從高斯分佈\(\mathcal{N}(\mu,\sigma^2)\),並且知道高斯分佈的概率密度函數是\(exp{(-\frac{1}{2}\frac{(x-\mu)^2}{\sigma^2})}\)。結合上面已經給出的三項的高斯分佈情況,例如
我們可以求得\(q(X_t|X_0)\)的概率密度函數爲\(exp(-\frac{1}{2}\frac{(X_t-\sqrt{\overline{a_t}}X_0)^2}{1-\overline{a_t}})\),其它兩項同理,它們計算後得到的最終的概率密度函數爲:
其中上面公式中\(\beta_t=1-\alpha_t\)。接着我們把上面公式的平方項展開,以\(X_{t-1}\)爲變量(因爲此時我們的目的是求得\(X_{t-1}\))合併同類項整理一下最後可以得到
我們在對比一下\(exp{(-\frac{1}{2}\frac{(x-\mu)^2}{\sigma^2})}=exp(-\frac{1}{2}(\frac{1}{\sigma^2}x^2-\frac{2\mu}{\sigma^2}x+\frac{\mu^2}{\sigma^2}))\)就能知道上面公式中對應的方差和均值:
- 方差
方差等式中的\(\alpha,\beta\)都是與分佈相關的固定值,即給定高斯分佈後,這些變量的值是固定的,所以方差是固定值。
- 均值
均值跟\(X_t\)和\(X_0\)有關 ,但是此時的已知量是\(X_t\),而\(X_0\)是未知的。不過我們可以估計一下\(X_0\)的值,通過前向過程我們知道 \(X_t=\sqrt{\overline{a}_t}X_0+\sqrt{1-\overline{a}_t}z_t\),那麼可以逆向估計一下 \(X_0=\frac{1}{\sqrt{\overline{a}_t}}(X_t-\sqrt{1-\overline{a}_t}z_t)\)。不過需要注意的是,這裏的\(X_0\)只是通過\(X_t\)估算得到的,並不是真實值。所以均值表達式還可以進一步簡化,即
每個時刻的均值和方差的表達式就都有了。不過,每個時刻的方差是個定值,很容易求解,而均值卻跟變量\(z_t\)相關。如果能求解得到\(z_t\),那麼只要給定一個t 時刻的隨機噪聲填滿的圖像\(X_t\),我們就能知道該時刻噪聲的均值和方差,那麼我們就可以通過採樣得到上一時刻的噪聲數據
\(\epsilon\)是服從(0,1)的正態分佈的隨機變量。至此,我們只需要引入神經網絡模型來預測 t 時刻的\(z_t\),即\(z_t=\text{diffusion_model}(x_t)\),模型訓練好後就能得到前一時刻的\(X_{t-1}\)了。
那麼要訓練模型,我們肯定得有標籤和損失函數啊。具體而言:
- \(x_t\)是模型的輸入
- \(z_t\)就是模型的輸出
- 標籤其實就是 forward 過程中每個時刻產生的噪聲數據\(\hat{z}_t\)
- 所以損失函數等於\(\text{loss}=mse(z_t, \hat{z}_t)\)
4. 代碼實現
接下來我們結合代碼來理解一下上述過程。
4.1 前向過程(加噪過程)
給定原始圖像\(X_0\)和加噪的超參數\(\alpha_t=1-\beta_t\)可以求得任意時刻對應的加噪後的數據\(X_t\),即
其中\(\overline{\alpha}_t=\alpha_t\alpha_{t-1}...\alpha_{1}\), \(\overline{z}_t\)是均值爲 0,標準差\(\sigma=\sqrt{1-\overline{\alpha}_t}\)的高斯變量。
下面是具體的代碼實現,首先是與噪聲相關超參數的設置和提前計算:
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torchvision import transforms
# 定義線性beta時間表
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
# 在給定的時間步數內,線性地從 start 到 end 生成 beta 值
return torch.linspace(start, end, timesteps)
T = 300 # 總的時間步數
betas = linear_beta_schedule(timesteps=T) # β,迭代100個時刻
# 預計算不同的超參數(alpha和beta)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0) # 累積乘積
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) # 前一個累積乘積
sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # alpha的平方根倒數
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) # alpha累積乘積的平方根
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) # 1-alpha累積乘積的平方根
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) # 計算後驗分佈q(x_{t-1}|x_t,x_0)的方差
接下來是具體的前向過程的計算,其中get_index_from_list
函數是爲了快速獲得指定 t 時刻對應的超參數的值,支持批量圖像操作。forward_diffusion_sample
則是前向擴散採樣函數。
def get_index_from_list(vals, time_step, x_shape):
"""
返回傳入的值列表vals(如β_t 或者α_t)中特定時刻t的值,同時考慮批量維度。
參數:
vals: 一個張量列表,包含了不同時間步的預計算值。
time_step: 一個包含時間步的張量,其值決定了要從vals中提取哪個時間步的值。
x_shape: 原始輸入數據的形狀,用於確保輸出形狀的一致性。
返回:
一個張量,其形狀與原始輸入數據x_shape相匹配,但是在每個批次中填充了特定時間步的vals值。
"""
batch_size = time_step.shape[0] # 獲取批量大小
out = vals.gather(-1, time_step.cpu()) # 從vals中按照時間步收集對應的值
# 重新塑形爲原始數據的形狀,確保輸出與輸入在除批量外的維度上一致
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(time_step.device)
# 前向擴散採樣函數
def forward_diffusion_sample(x_0, time_step, device="cpu"):
"""
輸入:一個圖像和一個時間步
返回:圖像對應時刻的噪聲版本數據
"""
noise = torch.randn_like(x_0) # 生成和x_0形狀相同的噪聲
sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, time_step, x_0.shape)
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(sqrt_one_minus_alphas_cumprod, time_step, x_0.shape)
# 計算均值和方差
return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(
device
), noise.to(device)
image = Image.open('xiaoxin.jpg').convert('RGB')
img_tensor = transforms.ToTensor()(image)
for idx in range(T):
time_step = torch.Tensor([idx]).type(torch.int64)
img, noise = forward_diffusion_sample(img_tensor, time_step)
plt.imshow(transforms.ToPILImage()(img)) # 繪製加噪圖像
4.2 訓練
我們忽略具體的模型結構細節,先看看訓練流程是怎樣的:
if __name__ == "__main__":
model = SimpleUnet()
T = 300
BATCH_SIZE = 128
epochs = 100
dataloader = load_transformed_dataset(batch_size=BATCH_SIZE)
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")
model.to(device)
optimizer = Adam(model.parameters(), lr=0.001)
for epoch in range(epochs):
for batch_idx, (batch_data, _) in enumerate(dataloader):
optimizer.zero_grad()
# 對一個 batch 內的數據採樣任意時刻的 time_step
t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
x_noisy, noise = forward_diffusion_sample(batch_data, t, device) # 計算得到指定時刻的 加噪後的數據 和 對應的噪聲數據
noise_pred = model(x_noisy, t) # 預測對應時刻的噪聲
loss = F.mse_loss(noise, noise_pred) # 計算噪聲預測的損失值
loss.backward()
optimizer.step()
這裏我們忽略模型架構的具體細節,只需要知道每次模型的計算需要 噪聲圖像(x_noisy
) 和 對應的時刻t
即可。
4.2 逆向過程(去噪採樣過程)
給定某一時刻的數據\(X_t\),該時刻的均值\(\mu\)和方差\(\sigma\)如下
通過對\(\mathcal{N}(\tilde\mu_t,\tilde\sigma_t^2)\)分佈進行採樣得到上一時刻的數據\(X_{t-1}=\tilde\mu_t+\tilde\sigma_t\epsilon\),\(z_t\)是模型訓練收斂後,在給定噪聲圖像和對應時刻 t 後計算得到的噪聲數據,\(\epsilon\)是正態分佈隨機變量。
實現代碼如下:
@torch.no_grad()
def sample_timestep(model, x, t):
"""
使用模型預測圖像中的噪聲,並返回去噪後的圖像。
如果不是最後一個時間步,則在此圖像上應用噪聲。
參數:
model - 預測去噪圖像的模型
x - 當前帶噪聲的圖像張量
t - 當前時間步的索引(整數或者整數型張量)
返回:
去噪後的圖像張量,如果不是最後一步,返回添加了噪聲的圖像張量。
"""
# 從預設列表中獲取當前時間步的beta值
betas_t = get_index_from_list(betas, t, x.shape)
# 獲取當前時間步的累積乘積的平方根的補數
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x.shape)
# 獲取當前時間步的alpha值的平方根的倒數
sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
# 調用模型來預測噪聲並去噪(當前圖像 - 噪聲預測)
model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
# 獲取當前時間步的後驗方差
posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
if t == 0:
# 如Luis Pereira在YouTube評論中指出的,論文中的時間步t有偏移
return model_mean
else:
# 生成與x形狀相同的隨機噪聲
noise = torch.randn_like(x)
# 返回模型均值加上根據後驗方差縮放的噪聲
return model_mean + torch.sqrt(posterior_variance_t) * noise
for i in reversed(range(0, T)):
t = torch.tensor([i], device='cpu', dtype=torch.long)
img = sample_timestep(model, img, t)
5. 總結
- 前向過程:
給定原始圖像\(X_0\)和加噪的超參數\(\alpha_t=1-\beta_t\)可以求得任意時刻對應的加噪後的數據\(X_t\),即
其中\(\overline{\alpha}_t=\alpha_t\alpha_{t-1}...\alpha_{1}\), \(\overline{z}_t\)是均值爲 0,標準差\(\sigma=\sqrt{1-\overline{\alpha}_t}\)的高斯變量。
- 逆向過程
給定某一時刻的數據\(X_t\),該時刻的均值\(\mu\)和方差\(\sigma\)如下
通過對\(\mathcal{N}(\tilde\mu_t,\tilde\sigma_t^2)\)分佈進行採樣得到上一時刻的數據\(X_{t-1}=\tilde\mu_t+\tilde\sigma_t\epsilon\),\(z_t\)是模型訓練收斂後,在給定噪聲圖像和對應時刻 t 後計算得到的噪聲數據,\(\epsilon\)是正態分佈隨機變量。迭代 t 次後即可得到 0 時刻的圖像了。