交替方向乘子法(ADMM)
網上的一些資料根本就沒有把ADMM的來龍去脈說清楚,發現只是一個地方簡單寫了一下流程,別的地方就各種抄,共軛函數,對偶梯度上升什麼的,都沒講清楚,給跪了。下面我來講講在機器學習中用得很多的ADMM方法到底是何方神聖。
共軛函數
給定函數f : R n → R f: \mathbb{R}^{n} \rightarrow \mathbb{R} f : R n → R ,那麼函數
f ∗ ( y ) = max x ( y T x − f ( x ) ) f^{*}(y)=\max _{x}( y^{T} x-f(x)) f ∗ ( y ) = x max ( y T x − f ( x ) )
就叫做它的共軛函數。其實一個更直觀的理解是:對一個固定的y y y ,將y T x y^Tx y T x 看成是一條斜率爲y y y 的直線,它和f ( x ) f(x) f ( x ) 關於x x x 的距離的最大值,就是f ∗ ( y ) f^*(y) f ∗ ( y ) 。百度百科 有關於這個直觀說法的一個解釋,看看就明白了。
關於共軛函數有幾點重要的說明:
不管f f f 凸不凸,它的共軛函數總是一個凸函數。
如果f f f 是閉凸的(閉指定義域是閉的),那麼f ∗ ∗ = f f^{**}=f f ∗ ∗ = f 。
如果f f f 是嚴格凸的,那麼
∇ f ∗ ( y ) = argmin z ( f ( z ) − y T z ) \nabla f^{*}(y)=\underset{z}{\operatorname{argmin}} (f(z)-y^{T} z) ∇ f ∗ ( y ) = z a r g m i n ( f ( z ) − y T z )
共軛總是頻頻出現在對偶規劃中,因爲極小問題總是容易湊出一個共軛:− f ∗ ( y ) = min x ( f ( x ) − y T x ) -f^{*}(y)=\min _{x} (f(x)-y^{T} x) − f ∗ ( y ) = min x ( f ( x ) − y T x ) 。
關於f ∗ f^* f ∗ 的凸性,這篇博客 給了一個比較直觀的圖示說明,下面我從數學上,不太嚴格地做個簡單證明。以一維的情況說明。
假設f f f 是一個凸函數,下面都不妨考慮函數的最值都不再邊界處取到。max x ( y x − f ( x ) ) \max _{x} (yx-f(x)) max x ( y x − f ( x ) ) 的極值點在f x ′ ( x ) = y f'_x(x)=y f x ′ ( x ) = y 處取到,定義g : = ( f x ′ ) − 1 g:=(f'_x)^{-1} g : = ( f x ′ ) − 1 ,那麼x = g ( y ) x=g(y) x = g ( y ) 可能會是一堆點。則有
f ∗ ( y ) = max x ( y x − f ( x ) ) = y g ( x ) − f ( g ( y ) ) f^*(y)=\max _{x}(yx-f(x))=yg(x)-f(g(y)) f ∗ ( y ) = x max ( y x − f ( x ) ) = y g ( x ) − f ( g ( y ) ) 進而
( f ∗ ( y ) ) y ′ = g ( y ) + y g y ′ ( y ) − f x ′ ( g ( y ) ) g y ′ ( y ) = g ( y ) (f^*(y))'_y = g(y)+yg'_y(y)-f'_x(g(y))g'_y(y) = g(y) ( f ∗ ( y ) ) y ′ = g ( y ) + y g y ′ ( y ) − f x ′ ( g ( y ) ) g y ′ ( y ) = g ( y ) 那麼
( f ∗ ( y ) ) y ′ ′ y = g y ′ ( y ) = ( ( f x ′ ) − 1 ) y ′ ≥ 0 (f^*(y))''_yy=g'_y(y)=((f'_x)^{-1})'_y \geq 0 ( f ∗ ( y ) ) y ′ ′ y = g y ′ ( y ) = ( ( f x ′ ) − 1 ) y ′ ≥ 0 故而,f ∗ f^* f ∗ 是凸的。
關於f ∗ ∗ = f f^{**}=f f ∗ ∗ = f ,也是容易證明的。我們假設f f f 是閉凸的。max y ( z y − f ∗ ( y ) ) \max _y(zy-f^*(y)) max y ( z y − f ∗ ( y ) ) 的值在g ( y ) = z g(y)=z g ( y ) = z 處取到,那麼
f ∗ ∗ ( z ) = max y ( z y − f ∗ ( y ) ) = f ( g ( g − 1 ( z ) ) ) = f ( z ) f^{**}(z) = \max_y(zy-f^{*}(y))=f(g(g^{-1}(z)))=f(z) f ∗ ∗ ( z ) = y max ( z y − f ∗ ( y ) ) = f ( g ( g − 1 ( z ) ) ) = f ( z )
z z z 換成x x x 就是f ( x ) f(x) f ( x ) 。
第三條非常重要,它說明了共軛函數的梯度,其實就是共軛函數取到極大值對應的x x x 值,它從( f ∗ ( y ) ) y ′ = g ( y ) (f*(y))'_y = g(y) ( f ∗ ( y ) ) y ′ = g ( y ) 就可以看出來。
對偶梯度上升法
有了上知識的鋪墊,我們就可以說清楚對偶上升方法了。以考慮等式約束問題爲例(一般約束問題也是類似的流程),假設f ( x ) f(x) f ( x ) 是嚴格凸的,我們考慮問題:
min x f ( x ) subject to A x = b \min _{x} f(x) \text { subject to } A x=b x min f ( x ) subject to A x = b 它的拉格朗日對偶問題是:
max u min x ( f ( x ) + u T ( A x − b ) ) \max _{u}\min _{x} (f(x)+u^T(Ax-b)) u max x min ( f ( x ) + u T ( A x − b ) )
有理論表明,若原問題和對偶問題滿足強對偶條件,即對偶函數關於u u u 的最大值等價於原優化問題關於x x x 的最小。那麼原問題和對偶問題對於x x x 是同解的。也就是說只要找到使得對偶問題對應最大的u u u ,其對應的x x x 就是原優化問題的解,那麼我們就解決了原始優化問題。
所以,下面我們來求解這個對偶問題。先把和x x x 無關的變量提出min x \min _x min x ,再想辦法湊出f ∗ f^* f ∗ ,因爲我們要用到對偶的性質。
KaTeX parse error: No such environment: split at position 7: \begin{̲s̲p̲l̲i̲t̲}̲
\max _u \min…
那麼對偶問題就成了 max u − f ∗ ( − A T u ) − b T u \max _{u}-f^{*}\left(-A^{T} u\right)-b^{T} u u max − f ∗ ( − A T u ) − b T u
這裏f ∗ f^* f ∗ 是f f f 的共軛,這裏max u \max _u max u 後面不加括號,表示它管着下面的所有,下同,不再重述。定義g ( u ) = − f ∗ ( − A T u ) − b T u g(u)=-f^{*}\left(-A^{T} u\right)-b^{T} u g ( u ) = − f ∗ ( − A T u ) − b T u ,我們希望能極大化g ( u ) g(u) g ( u ) ,一個簡單的想法是沿着g ( u ) g(u) g ( u ) 梯度上升的方向去走。注意到,
∂ g ( u ) = A ∂ f ∗ ( − A T u ) − b \partial g(u)=A \partial f^{*}\left(-A^{T} u\right)-b ∂ g ( u ) = A ∂ f ∗ ( − A T u ) − b
因此,利用共軛的性質,
∂ g ( u ) = A x − b where x ∈ argmin z f ( z ) + u T A z \partial g(u)=A x-b \text { where } x \in \underset{z}{\operatorname{argmin}} f(z)+u^{T} A z ∂ g ( u ) = A x − b where x ∈ z a r g m i n f ( z ) + u T A z
因爲f f f 是嚴格凸的,f ∗ f^* f ∗ 是可微的,那麼,就有了所謂的對偶梯度上升方法。從一個對偶初值u ( 0 ) u^{(0)} u ( 0 ) 開始,重複以下過程:
x ( k ) = argmin x f ( x ) + ( u ( k − 1 ) ) T A x u ( k ) = u ( k − 1 ) + t k ( A x ( k ) − b ) \begin{aligned}
&x^{(k)}=\underset{x}{\operatorname{argmin}} f(x)+\left(u^{(k-1)}\right)^{T} A x\\
&u^{(k)}=u^{(k-1)}+t_{k}\left(A x^{(k)}-b\right)
\end{aligned} x ( k ) = x a r g m i n f ( x ) + ( u ( k − 1 ) ) T A x u ( k ) = u ( k − 1 ) + t k ( A x ( k ) − b )
這裏的步長t k t_k t k 使用標準的方式選取的。近端梯度和加速可以應用到這個過程中進行優化。
交替方向乘子法
交替方向乘子法(ADMM)是一種求解具有可分離的凸優化問題的重要方法,由於處理速度快,收斂性能好,ADMM算法在統計學習、機器學習等領域有着廣泛應用。ADMM算法一般用於解決如下的凸優化問題:
min x , y f ( x ) + g ( y ) subject to A x + B y = c \min _{x, y} f(x)+g(y) \text { subject to } A x+B y=c x , y min f ( x ) + g ( y ) subject to A x + B y = c
其中的f f f 和g g g 都是凸函數。
它的增廣拉格朗日函數如下:
L p ( x , y , λ ) = f ( x ) + g ( y ) + λ T ( A x + B y − c ) + ( ρ / 2 ) ∥ A x + B y − c ∥ 2 2 , ρ > 0 L_{p}(x, y, \lambda)=f(x)+g(y)+\lambda^{T}(A x+B y-c)+(\rho / 2)\|A x+B y-c\|_{2}^{2}, \rho>0 L p ( x , y , λ ) = f ( x ) + g ( y ) + λ T ( A x + B y − c ) + ( ρ / 2 ) ∥ A x + B y − c ∥ 2 2 , ρ > 0
ADMM算法求解思想和推導同梯度上升法,最後重複迭代以下過程:
x k + 1 : = arg min x L p ( x , y , λ ) x k + 1 : = arg min y L p ( x , y , λ ) λ k + 1 : = λ k + ρ ( A x k + 1 + B y k + 1 − c ) \begin{aligned}
x^{k+1} &:=\arg \min _x L_{p}(x, y, \lambda) \\
x^{k+1} &:=\arg \min _y L_{p}(x, y, \lambda) \\
\lambda^{k+1} &:=\lambda^{k}+\rho\left(A x^{k+1}+B y^{k+1}-c\right)
\end{aligned} x k + 1 x k + 1 λ k + 1 : = arg x min L p ( x , y , λ ) : = arg y min L p ( x , y , λ ) : = λ k + ρ ( A x k + 1 + B y k + 1 − c ) 上述迭代可以進行簡化。
第一步簡化,通過公式∥ a + b ∥ 2 2 = ∥ a ∥ 2 2 + ∥ b ∥ 2 2 + 2 a T b \|a+b\|_{2}^{2}=\|a\|_{2}^{2}+\|b\|_{2}^{2}+2 a^{T} b ∥ a + b ∥ 2 2 = ∥ a ∥ 2 2 + ∥ b ∥ 2 2 + 2 a T b ,替換掉拉格朗日函數中的線性項λ T ( A x + B y − c ) \lambda^{T}(A x+B y-c) λ T ( A x + B y − c ) 和二次項ρ / 2 ∥ A x + B y − c ∥ 2 2 \rho/2\|A x+B y-c\|_{2}^{2} ρ / 2 ∥ A x + B y − c ∥ 2 2 ,可以得到
λ T ( A x + B y − c ) + ρ / 2 ∥ A x + B y − c ∥ 2 2 = ρ / 2 ∥ A x + B y − c + λ / ρ ∥ 2 2 − ρ / 2 ∥ λ / ρ ∥ 2 2 \lambda^{T}(A x+B y-c)+\rho/2\|A x+B y-c\|_{2}^{2}=\rho / 2\|A x+B y-c+\lambda/\rho\|_{2}^{2}-\rho / 2\|\lambda / \rho\|_{2}^{2} λ T ( A x + B y − c ) + ρ / 2 ∥ A x + B y − c ∥ 2 2 = ρ / 2 ∥ A x + B y − c + λ / ρ ∥ 2 2 − ρ / 2 ∥ λ / ρ ∥ 2 2
那麼ADMM的過程可以化簡如下: x k + 1 : = arg min x ( f ( x ) + ρ / 2 ∥ A x + B y k − c + λ k / ρ ∥ 2 2 y k + 1 : = arg min y ( g ( y ) + ρ / 2 ∥ A x k + 1 + B y − c + λ k / ρ ∥ 2 2 λ k + 1 : = λ k + ρ ( A x k + 1 + B y k + 1 − c ) \begin{aligned}
x^{k+1} &:={\arg \min _x}\left(f(x)+\rho / 2\left\|A x+B y^{k}-c+\lambda^{k} / \rho\right\|_{2}^{2}\right.\\
y^{k+1} &:={\arg \min _y}\left(g(y)+\rho / 2\left\|A x^{k+1}+B y-c+\lambda^{k} / \rho\right\|_{2}^{2}\right.\\
\lambda^{k+1} &:=\lambda^{k}+\rho\left(A x^{k+1}+B y^{k+1}-c\right)
\end{aligned} x k + 1 y k + 1 λ k + 1 : = arg x min ( f ( x ) + ρ / 2 ∥ ∥ A x + B y k − c + λ k / ρ ∥ ∥ 2 2 : = arg y min ( g ( y ) + ρ / 2 ∥ ∥ A x k + 1 + B y − c + λ k / ρ ∥ ∥ 2 2 : = λ k + ρ ( A x k + 1 + B y k + 1 − c )
第二步化簡,零縮放對偶變量u = λ / ρ u = \lambda/\rho u = λ / ρ ,於是ADMM過程可化簡爲:
x k + 1 : = arg min ( f ( x ) + ρ / 2 ∥ A x + B y k − c + u k ∥ 2 2 y k + 1 : = arg min ( g ( y ) + ρ / 2 ∥ A x k + 1 + B y − c + u k ∥ 2 2 u k + 1 : = u k + ( A x k + 1 + B y k + 1 − c ) \begin{aligned}
x^{k+1} &:={\arg \min}\left(f(x)+\rho / 2\left\|A x+B y^{k}-c+u^{k}\right\|_{2}^{2}\right.\\
y^{k+1} &:={\arg \min}\left(g(y)+\rho / 2\left\|A x^{k+1}+B y-c+u^{k}\right\|_{2}^{2}\right.\\
u^{k+1} &:=u^{k}+\left(A x^{k+1}+B y^{k+1}-c\right)
\end{aligned} x k + 1 y k + 1 u k + 1 : = arg min ( f ( x ) + ρ / 2 ∥ ∥ A x + B y k − c + u k ∥ ∥ 2 2 : = arg min ( g ( y ) + ρ / 2 ∥ ∥ A x k + 1 + B y − c + u k ∥ ∥ 2 2 : = u k + ( A x k + 1 + B y k + 1 − c )
ADMM相當於把一個大的問題分成了兩個子問題,縮小了問題的規模,分而治之。