pytorch的自動求導很好用,可以利用它對一些求導困難的問題做一些最優化問題,比如昨天狗菜提了一個問題:
求一個三維點的位置,使得它到一個直線族(三維)的距離之和最小
實際上就是求如下最優化問題
這個含絕對值的求導不太好做,如果交給Pytorch的話就比較容易了,構造一個簡單的例子,即過(1,1,1)的三條直線,顯然,最優解就是(1,1,1):
from torch.autograd import Variable
import torch
#Ax+By+Cz+(-A-B-C)=0
tmp1 = torch.rand(3)
a1,b1,c1 = tmp1
d1 = -sum(tmp1)
tmp1 = torch.rand(3)
a2,b2,c2 = tmp1
d2 = -sum(tmp1)
tmp1 = torch.rand(3)
a3,b3,c3 = tmp1
d3 = -sum(tmp1)
x = torch.rand(1)
y = torch.rand(1)
z = torch.rand(1)
x = Variable(x,requires_grad=True)
y = Variable(y,requires_grad=True)
z = Variable(z,requires_grad=True)
lr = 0.0001
for i in range(20000):
l1 = abs(a1*x+b1*y+c1*z+d1)/torch.sqrt(a1**2+b1**2+c1**2+d1**2)
l2 = abs(a2*x+b2*y+c2*z+d2)/torch.sqrt(a2**2+b2**2+c2**2+d2**2)
l3 = abs(a3*x+b3*y+c3*z+d3)/torch.sqrt(a3**2+b3**2+c3**2+d3**2)
loss = l1 + l2 + l3
loss.backward()
x.data = x.data - lr*x.grad.data
y.data = y.data - lr*y.grad.data
z.data = z.data - lr*z.grad.data
x.grad.data.zero_()
y.grad.data.zero_()#重要
z.grad.data.zero_()
x,y,z
最後的結果是:
(tensor([1.0004], requires_grad=True),
tensor([0.9996], requires_grad=True),
tensor([0.9999], requires_grad=True))
同樣的框架可以用來做很多簡單的最優化問題