交替方向乘子法(ADMM)的數學基礎

交替方向乘子法(ADMM)

網上的一些資料根本就沒有把ADMM的來龍去脈說清楚,發現只是一個地方簡單寫了一下流程,別的地方就各種抄,共軛函數,對偶梯度上升什麼的,都沒講清楚,給跪了。下面我來講講在機器學習中用得很多的ADMM方法到底是何方神聖。

共軛函數

給定函數f:RnRf: \mathbb{R}^{n} \rightarrow \mathbb{R},那麼函數
f(y)=maxx(yTxf(x))f^{*}(y)=\max _{x}( y^{T} x-f(x))
就叫做它的共軛函數。其實一個更直觀的理解是:對一個固定的yy,將yTxy^Tx看成是一條斜率爲yy的直線,它和f(x)f(x)關於xx的距離的最大值,就是f(y)f^*(y)百度百科有關於這個直觀說法的一個解釋,看看就明白了。

關於共軛函數有幾點重要的說明:

  • 不管ff凸不凸,它的共軛函數總是一個凸函數。

  • 如果ff是閉凸的(閉指定義域是閉的),那麼f=ff^{**}=f

  • 如果ff是嚴格凸的,那麼
    f(y)=argminz(f(z)yTz)\nabla f^{*}(y)=\underset{z}{\operatorname{argmin}} (f(z)-y^{T} z)

  • 共軛總是頻頻出現在對偶規劃中,因爲極小問題總是容易湊出一個共軛:f(y)=minx(f(x)yTx)-f^{*}(y)=\min _{x} (f(x)-y^{T} x)

關於ff^*的凸性,這篇博客給了一個比較直觀的圖示說明,下面我從數學上,不太嚴格地做個簡單證明。以一維的情況說明。

假設ff是一個凸函數,下面都不妨考慮函數的最值都不再邊界處取到。maxx(yxf(x))\max _{x} (yx-f(x))的極值點在fx(x)=yf'_x(x)=y處取到,定義g:=(fx)1g:=(f'_x)^{-1},那麼x=g(y)x=g(y)可能會是一堆點。則有
f(y)=maxx(yxf(x))=yg(x)f(g(y))f^*(y)=\max _{x}(yx-f(x))=yg(x)-f(g(y)) 進而
(f(y))y=g(y)+ygy(y)fx(g(y))gy(y)=g(y)(f^*(y))'_y = g(y)+yg'_y(y)-f'_x(g(y))g'_y(y) = g(y) 那麼
(f(y))yy=gy(y)=((fx)1)y0(f^*(y))''_yy=g'_y(y)=((f'_x)^{-1})'_y \geq 0 故而,ff^*是凸的。

關於f=ff^{**}=f,也是容易證明的。我們假設ff是閉凸的。maxy(zyf(y))\max _y(zy-f^*(y))的值在g(y)=zg(y)=z處取到,那麼
f(z)=maxy(zyf(y))=f(g(g1(z)))=f(z)f^{**}(z) = \max_y(zy-f^{*}(y))=f(g(g^{-1}(z)))=f(z)
zz換成xx就是f(x)f(x)

第三條非常重要,它說明了共軛函數的梯度,其實就是共軛函數取到極大值對應的xx值,它從(f(y))y=g(y)(f*(y))'_y = g(y)就可以看出來。

對偶梯度上升法

有了上知識的鋪墊,我們就可以說清楚對偶上升方法了。以考慮等式約束問題爲例(一般約束問題也是類似的流程),假設f(x)f(x)是嚴格凸的,我們考慮問題:
minxf(x) subject to Ax=b\min _{x} f(x) \text { subject to } A x=b 它的拉格朗日對偶問題是:
maxuminx(f(x)+uT(Axb))\max _{u}\min _{x} (f(x)+u^T(Ax-b))
有理論表明,若原問題和對偶問題滿足強對偶條件,即對偶函數關於uu的最大值等價於原優化問題關於xx的最小。那麼原問題和對偶問題對於xx是同解的。也就是說只要找到使得對偶問題對應最大的uu,其對應的xx就是原優化問題的解,那麼我們就解決了原始優化問題。

所以,下面我們來求解這個對偶問題。先把和xx無關的變量提出minx\min _x,再想辦法湊出ff^*,因爲我們要用到對偶的性質。
KaTeX parse error: No such environment: split at position 7: \begin{̲s̲p̲l̲i̲t̲}̲ \max _u \min…

那麼對偶問題就成了 maxuf(ATu)bTu\max _{u}-f^{*}\left(-A^{T} u\right)-b^{T} u
這裏ff^*ff的共軛,這裏maxu\max _u後面不加括號,表示它管着下面的所有,下同,不再重述。定義g(u)=f(ATu)bTug(u)=-f^{*}\left(-A^{T} u\right)-b^{T} u,我們希望能極大化g(u)g(u),一個簡單的想法是沿着g(u)g(u)梯度上升的方向去走。注意到,
g(u)=Af(ATu)b\partial g(u)=A \partial f^{*}\left(-A^{T} u\right)-b
因此,利用共軛的性質,
g(u)=Axb where xargminzf(z)+uTAz\partial g(u)=A x-b \text { where } x \in \underset{z}{\operatorname{argmin}} f(z)+u^{T} A z
因爲ff是嚴格凸的,ff^*是可微的,那麼,就有了所謂的對偶梯度上升方法。從一個對偶初值u(0)u^{(0)}開始,重複以下過程:
x(k)=argminxf(x)+(u(k1))TAxu(k)=u(k1)+tk(Ax(k)b)\begin{aligned} &x^{(k)}=\underset{x}{\operatorname{argmin}} f(x)+\left(u^{(k-1)}\right)^{T} A x\\ &u^{(k)}=u^{(k-1)}+t_{k}\left(A x^{(k)}-b\right) \end{aligned}
這裏的步長tkt_k使用標準的方式選取的。近端梯度和加速可以應用到這個過程中進行優化。

交替方向乘子法

交替方向乘子法(ADMM)是一種求解具有可分離的凸優化問題的重要方法,由於處理速度快,收斂性能好,ADMM算法在統計學習、機器學習等領域有着廣泛應用。ADMM算法一般用於解決如下的凸優化問題:
minx,yf(x)+g(y) subject to Ax+By=c\min _{x, y} f(x)+g(y) \text { subject to } A x+B y=c
其中的ffgg都是凸函數。

它的增廣拉格朗日函數如下:
Lp(x,y,λ)=f(x)+g(y)+λT(Ax+Byc)+(ρ/2)Ax+Byc22,ρ>0L_{p}(x, y, \lambda)=f(x)+g(y)+\lambda^{T}(A x+B y-c)+(\rho / 2)\|A x+B y-c\|_{2}^{2}, \rho>0

ADMM算法求解思想和推導同梯度上升法,最後重複迭代以下過程:
xk+1:=argminxLp(x,y,λ)xk+1:=argminyLp(x,y,λ)λk+1:=λk+ρ(Axk+1+Byk+1c)\begin{aligned} x^{k+1} &:=\arg \min _x L_{p}(x, y, \lambda) \\ x^{k+1} &:=\arg \min _y L_{p}(x, y, \lambda) \\ \lambda^{k+1} &:=\lambda^{k}+\rho\left(A x^{k+1}+B y^{k+1}-c\right) \end{aligned} 上述迭代可以進行簡化。

  • 第一步簡化,通過公式a+b22=a22+b22+2aTb\|a+b\|_{2}^{2}=\|a\|_{2}^{2}+\|b\|_{2}^{2}+2 a^{T} b,替換掉拉格朗日函數中的線性項λT(Ax+Byc)\lambda^{T}(A x+B y-c)和二次項ρ/2Ax+Byc22\rho/2\|A x+B y-c\|_{2}^{2},可以得到
    λT(Ax+Byc)+ρ/2Ax+Byc22=ρ/2Ax+Byc+λ/ρ22ρ/2λ/ρ22\lambda^{T}(A x+B y-c)+\rho/2\|A x+B y-c\|_{2}^{2}=\rho / 2\|A x+B y-c+\lambda/\rho\|_{2}^{2}-\rho / 2\|\lambda / \rho\|_{2}^{2}
    那麼ADMM的過程可以化簡如下: xk+1:=argminx(f(x)+ρ/2Ax+Bykc+λk/ρ22yk+1:=argminy(g(y)+ρ/2Axk+1+Byc+λk/ρ22λk+1:=λk+ρ(Axk+1+Byk+1c)\begin{aligned} x^{k+1} &:={\arg \min _x}\left(f(x)+\rho / 2\left\|A x+B y^{k}-c+\lambda^{k} / \rho\right\|_{2}^{2}\right.\\ y^{k+1} &:={\arg \min _y}\left(g(y)+\rho / 2\left\|A x^{k+1}+B y-c+\lambda^{k} / \rho\right\|_{2}^{2}\right.\\ \lambda^{k+1} &:=\lambda^{k}+\rho\left(A x^{k+1}+B y^{k+1}-c\right) \end{aligned}

  • 第二步化簡,零縮放對偶變量u=λ/ρu = \lambda/\rho,於是ADMM過程可化簡爲:
    xk+1:=argmin(f(x)+ρ/2Ax+Bykc+uk22yk+1:=argmin(g(y)+ρ/2Axk+1+Byc+uk22uk+1:=uk+(Axk+1+Byk+1c)\begin{aligned} x^{k+1} &:={\arg \min}\left(f(x)+\rho / 2\left\|A x+B y^{k}-c+u^{k}\right\|_{2}^{2}\right.\\ y^{k+1} &:={\arg \min}\left(g(y)+\rho / 2\left\|A x^{k+1}+B y-c+u^{k}\right\|_{2}^{2}\right.\\ u^{k+1} &:=u^{k}+\left(A x^{k+1}+B y^{k+1}-c\right) \end{aligned}

ADMM相當於把一個大的問題分成了兩個子問題,縮小了問題的規模,分而治之。

發佈了305 篇原創文章 · 獲贊 515 · 訪問量 129萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章