用pytorch做簡單的最優化問題

pytorch的自動求導很好用,可以利用它對一些求導困難的問題做一些最優化問題,比如昨天狗菜提了一個問題:

求一個三維點的位置,使得它到一個直線族(三維)的距離之和最小

實際上就是求如下最優化問題
miniAix+Biy+Ciz+DiAi2+Bi2+Ci2+Ci2min\sum_i \frac{|A_ix+B_iy+C_iz+D_i|}{\sqrt{ A^2_i+B^2_i+C^2_i+C^2_i}}

這個含絕對值的求導不太好做,如果交給Pytorch的話就比較容易了,構造一個簡單的例子,即過(1,1,1)的三條直線,顯然,最優解就是(1,1,1):

from torch.autograd import Variable
import torch

#Ax+By+Cz+(-A-B-C)=0
tmp1 = torch.rand(3)
a1,b1,c1 = tmp1
d1 = -sum(tmp1)

tmp1 = torch.rand(3)
a2,b2,c2 = tmp1
d2 = -sum(tmp1)

tmp1 = torch.rand(3)
a3,b3,c3 = tmp1
d3 = -sum(tmp1)


x = torch.rand(1)
y = torch.rand(1)
z = torch.rand(1)
x = Variable(x,requires_grad=True)
y = Variable(y,requires_grad=True)
z = Variable(z,requires_grad=True)

lr = 0.0001
for i in range(20000):
    l1 = abs(a1*x+b1*y+c1*z+d1)/torch.sqrt(a1**2+b1**2+c1**2+d1**2)
    l2 = abs(a2*x+b2*y+c2*z+d2)/torch.sqrt(a2**2+b2**2+c2**2+d2**2)
    l3 = abs(a3*x+b3*y+c3*z+d3)/torch.sqrt(a3**2+b3**2+c3**2+d3**2)
    loss = l1 + l2 + l3
    loss.backward()
    x.data = x.data - lr*x.grad.data
    y.data = y.data - lr*y.grad.data
    z.data = z.data - lr*z.grad.data
    x.grad.data.zero_()
    y.grad.data.zero_()#重要
    z.grad.data.zero_()
x,y,z   

最後的結果是:

(tensor([1.0004], requires_grad=True),
 tensor([0.9996], requires_grad=True),
 tensor([0.9999], requires_grad=True))

同樣的框架可以用來做很多簡單的最優化問題

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章