2D函數最小值優化

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch

def himmelblau(x):
    return (x[0] ** 2 + x[1] - 11) ** 2 + (x[0] + x[1] ** 2 -7) ** 2

x = np.arange(-6,6,0.1)
y = np.arange(-6,6,0.1)
print('x,y range:', x.shape, y.shape)
X,Y = np.meshgrid(x,y)
print('X,Y maps:',X.shape,Y.shape)
Z = himmelblau([X,Y])

fig = plt.figure('himmelblau')
ax = fig.gca(projection='3d')
ax.plot_surface(X,Y,Z)
ax.view_init(60,-30)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()

#分別用[0.,0.],[4.,0.],[-4.,0.]對x進行初始化
x = torch.tensor([0.,0.], requires_grad=True)
optimizer = torch.optim.Adam([x], lr=1e-3)
for step in range(20000):
    pred = himmelblau(x)

    optimizer.zero_grad()  #將梯度手動清零,避免梯度累加
    pred.backward()   #反向傳播,計算當前的梯度值
    optimizer.step()  #根據梯度更新網絡參數

    if step % 2000 == 0:
        print('strp{}: x = {}, f(x) = {}'.format(step,x.tolist(),pred.item()))
x,y range: (120,) (120,)
X,Y maps: (120, 120) (120, 120)
strp0: x = [0.0009999999310821295, 0.0009999999310821295], f(x) = 170.0
strp2000: x = [2.3331806659698486, 1.9540692567825317], f(x) = 13.730920791625977
strp4000: x = [2.9820079803466797, 2.0270984172821045], f(x) = 0.014858869835734367
strp6000: x = [2.999983549118042, 2.0000221729278564], f(x) = 1.1074007488787174e-08
strp8000: x = [2.9999938011169434, 2.0000083446502686], f(x) = 1.5572823031106964e-09
strp10000: x = [2.999997854232788, 2.000002861022949], f(x) = 1.8189894035458565e-10
strp12000: x = [2.9999992847442627, 2.0000009536743164], f(x) = 1.6370904631912708e-11
strp14000: x = [2.999999761581421, 2.000000238418579], f(x) = 1.8189894035458565e-12
strp16000: x = [3.0, 2.0], f(x) = 0.0
strp18000: x = [3.0, 2.0], f(x) = 0.0

可以看到初始值爲(0,0)時函數的最小值的座標爲(3,2)。
在這裏插入圖片描述
而不同的初始值得到的結果也是不同的,所以初始值不能隨意設置。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章