作者:John Schulman(OpenAI)
譯者:朱小虎 Xiaohu (Neil) Zhu(CSAGI / University AI)
原文鏈接:http://joschu.net/blog/kl-approx.html
術語: 散度( Divergence); 近似(Approximation); Monte-Carlo 估計(Monte-Carlo estimator)
本文討論 散度的 Monte-Carlo 近似:
這解釋了之前使用了一個技巧,針對來自 中的樣本 以 樣本平均來近似 ,而不是更加標準的 . 本文談談爲何該表達是一個 KL 散度的好的估計(儘管有偏 biased),以及如何讓其變得無偏(unbiased)保證其低方差。
我們計算 KL 的選擇取決於對 和 的訪問方式。這裏,我們假設能夠對任意 計算概率(或者概率密度)和 ,但是我們不能解析地跑遍 求和。爲何我們不能解析地計算呢?
- 準確地計算此和需要太多計算或者內存
- 沒有閉式形式
- 我們可以通過僅僅存儲 而非整個分佈來簡化代碼。只要 僅僅用來作爲診斷工具這是一個合理的選擇,這也是強化學習中常見情況
最常用的估計求和或者積分的策略是使用 Monte-Carlo 估計。給定樣本 ,我們如何構造好的估計?
一個好的估計是無偏的(即有正確的均值)並且低方差。我們知道一個無偏估計(在從 中採樣的樣本下)是 。但是,它有高方差,因爲它對樣本的一半是負,而 KL 總是爲正。讓我們稱此簡易估計 ,其中我們已經定義了比例 後面也會多次出現此值。
另一個替代估計有低的方差不過是有偏的,即 。我們不妨稱此爲 。直覺上看,看起來更加好因爲每個樣本告訴了我們 和 之間相距多遠,並且總爲正。實驗上看,實際有比 更低的方差,並也有相當低的偏差(bias)。(下面在實驗中給出此點)。
關於估計 爲何有低偏差有一個很好的原因:其期望是一個 f-散度(divergence)。一個 f-散度 被定義爲關於一個凸函數 ,。KL 散度和其他有名的概率距離均是 -散度。現在這是關鍵的難以被發現的事實:所有具有可微函數 的 -散度與 散度當 接近 時的二階。也就是說,對一個參數化分佈 ,
其中 是關於 的 Fisher 信息矩陣在 的值。
期望 是 f-散度,其中 ,而 對應於 。易見,兩者均有 ,所以兩者看起來對 有相同的二階距離函數。
是否可以寫出一個 散度估計既是無偏又是低方差的呢?一般達成低方差的方法是通過一個控制變量。就是說,取 並加上某個期望爲零但是與 負相關的量。保證期望爲零唯一有趣的量是 。所以,對任意的 ,表達式 是 的無偏估計。我們可以做一些計算來最小化這個估計的方差,對 求解。但不幸的是,我們獲得一個表達式,它依賴於 和 並難以解析地計算。
但是,我們可以使用一個更爲簡單的策略來選擇一個好的 。注意因爲 是凹函數,。因此,如果我們令 ,上面的表達式會確保爲正。它度量了 和它的切線的豎直距離。這讓我們有了估計 。
通過看凸函數和它切平面的差距來度量距離的想法出現在很多領域。這杯稱爲 Bregman 散度並有很多優美性質。
我們可以推廣上面想法來獲得一個好的,總是爲正的對任何 f-散度的估計,大多數明顯是另一個 散度即 (注意這裏的 和 調換了次序)。因爲 是凸函數,並且 ,下面是 f-散度的一個估計:。這總是爲正因爲它是 f 和其在 處的距離,並且凸函數在它們的切線上方。現在 對應於 ,其有 ,使得我們有了估計 。
總結一下,我們有下列估計(對樣本 和 ):
現在我們比對這三個對 估計的偏差和方差。假設 。這裏正確的 KL 散度爲 。
k | bias/true | stdev/true |
---|---|---|
k1 | 0 | 20 |
k2 | 0.002 | 1.42 |
k3 | 0 | 1.42 |
注意 k2 的偏差非常低:爲 0.2%。
現在我們嘗試對大一些的 散度近似。給我們一個真實 散度爲 。
k | bias/true | stdev/true |
---|---|---|
k1 | 0 | 2 |
k2 | 0.25 | 1.73 |
k3 | 0 | 1.7 |
這裏,k2 的偏差更大一些。k3 甚至有比 k2 更低的標準差同時還是無偏的,所以它看起來也是在一個嚴格意義上更好的估計。
這裏是我用來產生這些結果的代碼:
import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):
print((k.mean() - truekl) / truekl, k.std() / truekl)