《AE-OT: A NEW GENERATIVE MODEL BASED ON EXTENDED SEMI-DISCRETE OPTIMAL TRANSPORT》中文筆記-2: AE-OT算法

文章信息:

D. An, Y. Guo, N. Lei, Z. Luo, S.-T. Yau, and X. Gu, “AE-OT: A NEW GENERATIVE MODEL BASED ON EX- TENDED SEMI-DISCRETE OPTIMAL TRANSPORT,” 2020, p. 19.

發表於2020年ICLR(International Conference on Learning Representations)

 

AE-OT筆記:

《AE-OT: A NEW GENERATIVE MODEL BASED ON EXTENDED SEMI-DISCRETE OPTIMAL TRANSPORT》中文筆記-1: 總述與簡介

《AE-OT: A NEW GENERATIVE MODEL BASED ON EXTENDED SEMI-DISCRETE OPTIMAL TRANSPORT》中文筆記-2: AE-OT算法

《AE-OT: A NEW GENERATIVE MODEL BASED ON EXTENDED SEMI-DISCRETE OPTIMAL TRANSPORT》中文筆記-3: 實驗與結果

 

在介紹算法時,這篇文章採用的是總-分的結構:首先對AE-OT model進行overview(Overview of AE-OT Model ),之後對AE-OT中三個重要步驟(Semi-Discrete OT Map ,Piece-wise Linear Extension ,SingularSetDetection)進行分別描述.具體如下:

3 COMPUTATIONAL ALGORITHMS

3.1 Overview of AE-OT Model 

(其中,AE: θ和\xi分別是編碼器和解碼器的參數.在latent space, latent code被聚類成了三個模式; OT:不同模式之間的singular set用虛線畫出. 最終,由extended OT映射和解碼器映射的組合即可生成出圖像.)

上圖給出了AE-OT模型.它有兩個major components:

(1)(AE)

訓練自動編碼器對從圖像空間X到潛在空間Z的數據流形進行編碼(fθ),並將數據分佈映射到latent code distribution(什麼是latent code?)

InfoGAN中應該有具體的解釋,但懶得讀了,大概率和latent variable是一個意思:"可解釋的 隱變量c,稱作爲latent code,而我們希望通過約束c與生成數據之間的關係,可以使得c裏面包含有對數據的可解釋的信息,如對MNIST數據,c可以分爲categorical latent code代 來表數字種類信息(0~9),以及continuous latent code來表示傾斜度、筆畫粗細等等。"

之後解碼器g_ \xi再將latent code解碼回數據流形(最終圖像). 

(2)(OT)

計算從噪聲分佈到latent code分佈的最佳傳輸映射(OT maps)T。

  • 首先,Brenier potential能夠通過凸優化的過程被找到( Gu et al. (2016)).這個Brenier勢的梯度是一個半離散(semi-diesrete)的最優傳輸映射, 這個OT映射的目標是訓練樣本latent codes的離散集合.
  • 之後, 傳輸映射被分段線性地擴展到全局連續映射\tilde{T},其中圖像域成爲通過三角剖分上述latent code而獲得的simplicial complex(simplicial complex是什麼意思?簡單組合體??三角剖分是什麼意思?)。
  • 最後,定位source domain(噪聲domain)的singularity set從而避免這些點產生新的樣本,這樣得到的映射纔會是一個符合AE特性的連續映射.

綜上所述,給定一個隨機噪聲x,即可通過g_{\xi} \circ \tilde{T}(x)得到最終的生成圖像.

3.2 Semi-Discrete OT Map 

假設source measureμ是絕對連續的並定義在凸集\Omega \in\mathbb{R}^{D}上;(source measure,源測度的含義是什麼?)

target domain是一個離散集合Y={y1,y2,...,yn},y_i\in\mathbb{R}^{D}

target measure是一個Dirac measure,\nu=\sum_{i=1}^{n} \nu_{i} \delta\left(y-y_{i}\right), i=1,2,...,n,它的total mass和source measure的total mass是相等的,即\mu(\Omega)=\sum_{i=1}^{n} \nu_{i}。(爲什麼source measure和target measure的total mass要相等?target measure是什麼?就是正確的latent code嗎?)

在一個半離散的傳輸映射(semi-discrete transport map)T:Ω-->Y下,引入一個劃分(cell decomposition), \Omega =\cup _{i=1}^{n}W_{i},從而讓每一個cell Wi裏的每一個x都被映射到目標yi, 即T:x\in W_{i} \mapsto y_i;

若每一個元胞Wi的μ測度等於圖像T(Wi)=yi, 即μ(Wi)=\nui,則映射T是測度不變的(測度不變應該是μ(T^{-1}(Wi)=μ(Wi)?那爲什麼能推出來測度不變?)記作T_{\#} \mu=\nu;

代價函數爲:

給定c:  Ω × Y → R, 其中c(x,y)代表將單位質量從x運輸到y的成本(transporting a unit mass from x to y)。則T的總成本爲

\int_{\Omega} c(x, T(x)) d \mu(x)=\sum_{i=1}^{n} \int_{W_{i}} c\left(x, y_{i}\right) d \mu(x)

(爲什麼是對dμ(x)進行積分?μ測度的含義是什麼?)

那麼,semi-discrete最優傳輸映射的目標就是找到一個測度保持映射最小化代價函數

T^{*}:=\arg \min _{T_{\#} \mu=\nu} \int_{\Omega} c(x, T(x)) d \mu(x)

Note: 代價函數這個部分的思想和W-GAN中的Earth-Mover(EM)有些像。

----------------------------------------------------------------------------------------------------------------------------------------------------------------

c(x,y)=1/2*||x-y||^{2},由Brenier's theorem,

semi-discrete OT map被一個piece-wise(PL)凸函數(即Brenier potential(Brenier勢u_{h}))的梯度圖給出.

u_{h}: \Omega \rightarrow \mathbb{R}, u_{h}(x):=\max _{i=1}^{n}\left\{\pi_{h, i}(x)\right\}, \text { where } \pi_{h, i}(x)=\left\langle x, y_{i}\right\rangle+h_{i}

它是一個對應於yi∈Y的超平面。如下圖所示,u_{h}圖(Brenier勢)的投影將Ω分解成了元胞W_{i}(h),每一個W_{i}(h)都是supporting plane\pi_{h, i}(x)的投影。

\sum _{i}h_{i}=0的條件下,高度向量h是下面這個凸能量(convex energy)的唯一的優化器:

E(h)=\int_{0}^{h} \sum_{i=1}^{n} w_{i}(\eta) d \eta_{i}-\sum_{i=1}^{n} h_{i} \nu_{i}

其中,w_{i}(\eta)是Wi(η)的μ-volume.凸能量(convex energy)E(h)能夠通過簡單的梯度下降方法\nabla E(h)=\left(w_{i}(h)-\nu_{i}\right)^{T}求解。

-------------------------------------------------------------------------------------------------------------------------------------------------------------

所以,現在的關鍵是計算每一個元胞Wi(h)的μ-volumew_{i}(\eta),這個能夠通過傳統的蒙特卡洛方法進行估計。

從原測度µ distribution中隨機採N個點,\left\{x_{j}\right\} \sim_{i . i . d .} \mu\forall j \in \mathcal{J},那麼估計的每一個元胞的μ-volume就是

\hat{w}_{i}(h)=\#\left\{j \in \mathcal{J} | x_{j} \in W_{i}(h)\right\} / N

給定xj,我們可以通過i=\arg \max _{i}\left\{\left\langle x_{j}, y_{i}\right\rangle+h_{i}\right\}, i=1,2, \ldots, n得到Wi;

當N足夠大時,\hat{w}_{i}(h)就會收斂到{w}_{i}(h)

這樣,凸能量(convex energy)的梯度就能夠被估計出來:

\nabla E \approx\left(\hat{w}_{i}(h)-\nu_{i}\right)^{T}

一旦梯度被估計出來,就可以通過Adam算法來minimzie凸能量(convex energy)。

上面的蒙特卡洛方法是可以在GPU上並行計算實現的,原因是:x的採樣彼此獨立,並且x所在的單元僅涉及矩陣乘法和排序。

---------------------------------------------------------------------------------------------------------------------------------------------------------------

上面的蒙特卡洛方法的精度和樣本數量平方成反比,但顯然耗時和樣本數量正相關,所以就存在精度和速度之間的矛盾。文章提出了一種策略:當E(h)連着好多次都不降,也就是該提高精度的時候,就增加樣本數量。

上述算法總結起來如下所示:

上述算法總結起來就是用採樣並用蒙特卡羅法得到\hat{w}_{i}(h),再用梯度下降不斷優化\nabla E \approx\left(\hat{w}_{i}(h)-\nu_{i}\right)^{T},在這個過程中,不斷更新h。最終,由Brenier's theorem,最優映射即爲T(\cdot) \leftarrow \nabla\left(\max _{i}\left\langle\cdot, y_{i}\right\rangle+h_{i}\right)

 

3.3 Piece-wise Linear Extension 

半離散的OT map: \nabla u_{h}: \Omega \rightarrow Y,它將所有Ω內的x映射到latent codes,而並不會生成新的樣本。

(還是不理解爲什麼要extend semi-discrete OT map to piecewise linear mapping)u_{h}在source domain的投影引起了Ω的一個cell decomposition,其中每一個cell的μ-volume是\nu_{i},並被映射到對應的yi. 

通過將這些celss用它們的μ-mass來表示:c_{i}:=\int_{W_{i}(h)} x d \mu(x),我們可以得到點對點(point-wise)的映射t: c_{i} \mapsto y_{i}

Cell decomposition的龐加萊式(什麼是龐加萊式?)引入了中心C={ci}的三角剖分

如果W_{i} \cap W_{j} \neq \emptyset,則鏈接{ci和cj}來構成邊[ci,cj]。類似地,如果W_{i_{0}} \cap W_{i_{1}} \cdots \cap W_{i_{k}} \neq \emptyset,則就會得到一個k維的單純形[ci0.ci1,...,cik]. 

所有的這些簡圖構成了C(一個簡單複合體)的三角剖分,記作\mathcal{T}(C),即爲下圖中的綠色三角形。

我們也可用同樣的方法對Y進行三角剖分得到Y的三角剖分\mathcal{T}(Y),,如下圖所示:

一旦從分佈μ中得到一個隨機樣本x,就能夠得到包含x的\mathcal{T}(C)中的單純形(simplex)σ。(simplex σ是什麼?,爲什麼它會在\mathcal{T}(C)中?這裏的單純形σ就是上圖中由綠色構成的多邊形)

假設這個單純形σ有(d+1)個頂點\left\{c_{i_{0}}, c_{i_{1}}, \ldots, c_{i_{d}}\right\},則σ中x的重心座標定義爲:

x=\sum_{k=0}^{d} \lambda_{k} c_{i_{k}},且\sum_{k=0}^{d} \lambda_{k}=1同時所有λk均非負。

之後,在這個piece-wise linear map下由x生成的latent code爲:

\tilde{T}(x)=\sum_{k=0}^{d} \lambda_{k} y_{i_{k}}

下圖就展示出了一個這樣的變換:

 

因爲所有的yi都被用來建立簡單複合體\mathcal{T}(Y),因此不存在被丟掉的模式(爲什麼所有節點都在就不會有模式被丟掉了?會不會有Y中的節點沒有被映射呢?)。

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------

在實際應用中,µ-mass center ci 是通過所有Wi(h)內的蒙特卡羅樣本的平均值來估計的,即:

\hat{c}_{i}=\sum_{x_{j} \in W_{i}} x_{j} / \#\left\{x_{j} \in W_{i}\right\}, \text { where } x_{j} \sim \mu

連接信息\mathcal{T}(C)過於複雜以致於難以在高維空間中construct and store,因此\mathcal{T}(C)並沒有explicitly built;Instead,文章發現包含x的單純形\sigma \in \mathcal{T}(C)如下:

給定一個隨機點x∈Ω,計算並對它與各個中心之間的歐氏距離d\left(x, \hat{c}_{i}\right), i=1,2, \ldots, n進行升序排序。假設前{d+1}項是

\left\{d\left(x, \hat{c}_{i_{0}}\right), d\left(x, \hat{c}_{i_{1}}\right), \ldots, d\left(x, \hat{c}_{i_{d}}\right)\right\},那麼σ就由\left\{\hat{c}_{i_{k}}\right\}構成。重心座標\hat{\lambda}_{i_{k}}可以通過下式進行估計

\hat{\lambda}_{i_{k}}=d^{-1}\left(x, \hat{c}_{i_{k}}\right) / \sum_{k=0}^{d} d^{-1}\left(x, \hat{c}_{i_{k}}\right)

但這樣可能會產生一些supurious,因此爲了克服這個問題,還需要進一步檢測source domain Ω中的singular set。

 

SingularSetDetection

如果目標分佈有多個模式或者是非凸(concave)的,那麼就會存在singular sets\Sigma \subset \Omega,其中 Brenier potential 連續但是並不可微,所以它的梯度圖,即the transport map並不連續。

如上圖所示,source distribution均勻地定義在Ω上,target經驗分佈有兩個modes。在Brenier potentialu_{h}中有一個脊(紅色線)。它地投影就是singular set Σ(即Ω中的紅線)。Ω\Σ存在兩個連接的組成部分,它們中的每一個都映射到了一個模式上。Σ由一維切面組成。

如果W_{i}(h) \cap W_{j}(h) \subset \Sigma,則u_{h}上的兩個支撐面\pi_{h, i} \text { and } \pi_{h, j}之間的夾角就會非常大。

因此,在Brenier potential圖中,我們選擇了夾角大於給定閾值的小平面對,它們相交的投影給出了singular set Σ的一維位置。

這樣,在生成新的latent code的時候,如果一個隨機採樣就在Σ附近,那麼通過這篇文章的方法就能夠拋棄這個樣本,從而避免mode mixture現象(但傳統生成過程就沒有這個機制,就會生成到不同模式之間的gaps,從而產生mode miture問題)。

------------------------------------------------------------------------------------------------------------------------------------------------

在上圖中的左側圖,給定extended OT map\tilde{T}(x),有一些多邊形的邊與singular set(紅色線)相交,這意味着不同的頂點分屬於不同的模式。如果採樣得到的x正好在這種多面體(紅色虛線構成的多面體上),那麼就abandon這個x(這在上圖中的右側也有表示)。

具體而言,給定x,通過下面的方法能夠檢測出來它是不是屬於singular set:

檢查\pi_{i_{0}} \text { and } \pi_{i_{k}}, k=1,2, \ldots, d之間的角度\theta_{i_{k}}\theta_{i_{k}}=\left\langle y_{i_{0}}, y_{i_{k}}\right\rangle /\left\|y_{i_{0}}\right\| \cdot\left\|y_{i_{k}}\right\|

如果所有的\theta_{i_{k}}都比閾值\hat{\theta}大,則判定x屬於singular set並abandon it;

 

 

或者,文章選擇了一個滿足\theta_{i_{k}} \leq \hat{\theta}的子集\left\{\pi_{i_{k}}\right\},記作\left\{\pi_{\hat{i}_{k}}, k=0,1, \ldots, d_{1}\right\}

接下來就可以計算\lambda_{k}=d^{-1}\left(x, \hat{c}_{\hat{i}_{k}}\right) / \sum_{j=0}^{d_{1}} d^{-1}\left(x, \hat{c}_{\hat{i}_{j}}\right) \text { and } \tilde{T}(x)=\sum_{k=0}^{d_{1}} \lambda_{k} T\left(\hat{c}_{\hat{i}_{k}}\right)

顯然,\tilde{T}(\cdot )對於latent code比較密集的區域平滑了離散函數T(.); 對於latent code比較稀疏的地方則保留了T()的離散型。這樣,就能夠避免生成spurious latent code並且提高重建質量。生成新的latent code的辦法如下所示:

 

總結

本部分先是總述了AE-OT Model,其中AE部分較爲常規,OT部分是重點也是難點。所以,接下來,對OT的三個步驟分別進行了描述:

  • Semi-Discrete OT Map: 由原有定理+梯度下降法完成;
  • Piece-wise Linear Extension:由中心座標+公式計算完成;
  • SingularSetDetection:由夾角閾值判斷完成。

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章