理解 KL 散度的近似

作者:John Schulman(OpenAI)

譯者:朱小虎 Xiaohu (Neil) Zhu(CSAGI / University AI)

原文鏈接http://joschu.net/blog/kl-approx.html

術語\text{KL} 散度(\text{KL} Divergence); 近似(Approximation); Monte-Carlo 估計(Monte-Carlo estimator)

本文討論 \text{KL} 散度的 Monte-Carlo 近似:

\text{KL}[q,p] = \sum_{x}q(x)\log \frac{q(x)}{p(x)} = E_{x\sim q}[\log \frac{q(x)}{p(x)}]

這解釋了之前使用了一個技巧,針對來自 q中的樣本 x\frac{1}{2}(\log p(x)-\log q(x))^2樣本平均來近似 \text{KL}[q,p],而不是更加標準的 \log \frac{q(x)}{p(x)}. 本文談談爲何該表達是一個 KL 散度的好的估計(儘管有偏 biased),以及如何讓其變得無偏(unbiased)保證其低方差。

我們計算 KL 的選擇取決於對 pq的訪問方式。這裏,我們假設能夠對任意 x計算概率(或者概率密度)p(x)q(x),但是我們不能解析地跑遍 x求和。爲何我們不能解析地計算呢?

  1. 準確地計算此和需要太多計算或者內存
  2. 沒有閉式形式
  3. 我們可以通過僅僅存儲 \text{log-prob}而非整個分佈來簡化代碼。只要 \text{KL}僅僅用來作爲診斷工具這是一個合理的選擇,這也是強化學習中常見情況

最常用的估計求和或者積分的策略是使用 Monte-Carlo 估計。給定樣本 x_1, x_2, \cdots \sim q,我們如何構造好的估計?

一個好的估計是無偏的(即有正確的均值)並且低方差。我們知道一個無偏估計(在從 q中採樣的樣本下)是 \log \frac{q(x)}{p(x)}。但是,它有高方差,因爲它對樣本的一半是負,而 KL 總是爲正。讓我們稱此簡易估計 k_1 = \log \frac{q(x)}{p(x)} = - \log r,其中我們已經定義了比例 r = \frac{p(x)}{q(x)}後面也會多次出現此值。

另一個替代估計有低的方差不過是有偏的,即 \frac{1}{2} (\log \frac{p(x)}{q(x)})^2 = \frac{1}{2} (\log r)^2。我們不妨稱此爲 k_2。直覺上看,k_2看起來更加好因爲每個樣本告訴了我們 pq之間相距多遠,並且總爲正。實驗上看,k_2實際有比 k_1更低的方差,並也有相當低的偏差(bias)。(下面在實驗中給出此點)。

關於估計 k_2爲何有低偏差有一個很好的原因:其期望是一個 f-散度(divergence)。一個 f-散度 被定義爲關於一個凸函數 fD_f(p,q) = \text{E}_{x\sim q} [f(\frac{p(x)}{q(x)})]。KL 散度和其他有名的概率距離均是 f-散度。現在這是關鍵的難以被發現的事實:所有具有可微函數 ff-散度與 \text{KL} 散度當 q接近 p時的二階。也就是說,對一個參數化分佈 p_{\theta}

D_f(p_0,p_\theta) = \frac{f''(1)}{2} \theta^\intercal F \theta + O(\theta^3)

其中 F 是關於 p_\theta的 Fisher 信息矩陣在 p_\theta = p_0的值。

期望 E_q[k_2] = E_q[\frac{1}{2} (\log r)^2]是 f-散度,其中 f(x) = \frac{1}{2}(\log x)^2,而 \text{KL}[q,p]對應於 f(x) = - \log x。易見,兩者均有 f''(1) = 1,所以兩者看起來對 p\approx q有相同的二階距離函數。

是否可以寫出一個 \text{KL} 散度估計既是無偏又是低方差的呢?一般達成低方差的方法是通過一個控制變量。就是說,取 k_1並加上某個期望爲零但是與 k_1負相關的量。保證期望爲零唯一有趣的量是 \frac{p(x)}{q(x)} - 1 = r - 1。所以,對任意的 \lambda,表達式 -\log r + \lambda (r - 1)KL[q,p]的無偏估計。我們可以做一些計算來最小化這個估計的方差,對 \lambda求解。但不幸的是,我們獲得一個表達式,它依賴於 pq並難以解析地計算。

但是,我們可以使用一個更爲簡單的策略來選擇一個好的 \lambda。注意因爲 \log是凹函數,\log(x) \leq x - 1。因此,如果我們令 \lambda = 1,上面的表達式會確保爲正。它度量了 \log(x)和它的切線的豎直距離。這讓我們有了估計 k_3 = (r - 1) - \log r

通過看凸函數和它切平面的差距來度量距離的想法出現在很多領域。這杯稱爲 Bregman 散度並有很多優美性質。

我們可以推廣上面想法來獲得一個好的,總是爲正的對任何 f-散度的估計,大多數明顯是另一個 \text{KL} 散度即 \text{KL}[p,q](注意這裏的 pq調換了次序)。因爲 f是凸函數,並且 E_q[r] = 1,下面是 f-散度的一個估計:f(r) - f'(1)(r-1)。這總是爲正因爲它是 f 和其在 r=1處的距離,並且凸函數在它們的切線上方。現在 \text{KL}[p,q]對應於 f(x) = x \log x,其有 f'(1) = 1,使得我們有了估計 r \log r - (r - 1)

總結一下,我們有下列估計(對樣本 x\sim qr = \frac{p(x)}{q(x)}):

  • \text{KL}[p,q] : r \log r - (r - 1)
  • \text{KL}[q,p] : (r - 1) - \log r

現在我們比對這三個對 \text{KL}[p,q]估計的偏差和方差。假設 q=N(0,1), p=N(0.1,1)。這裏正確的 KL 散度爲 0.005

k bias/true stdev/true
k1 0 20
k2 0.002 1.42
k3 0 1.42

注意 k2 的偏差非常低:爲 0.2%。

現在我們嘗試對大一些的 \text{KL} 散度近似。p=N(1,1)給我們一個真實 \text{KL} 散度爲 0.5

k bias/true stdev/true
k1 0 2
k2 0.25 1.73
k3 0 1.7

這裏,k2 的偏差更大一些。k3 甚至有比 k2 更低的標準差同時還是無偏的,所以它看起來也是在一個嚴格意義上更好的估計。

這裏是我用來產生這些結果的代碼:

import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):
    print((k.mean() - truekl) / truekl, k.std() / truekl)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章