拋物線法
拋物線法也叫做二次插值法,基本思想爲:在搜索的區間中不斷的使用二次多相似去近似目標函數,並且逐步用插值多項式去逼近線搜索問題。具體推導見《最優化方法及其matlab程序設計》P18。
代碼如下
import numpy as np
import matplotlib.pyplot as plt
import math
def phi(x):
'''
測試函數1
:param x:
:return:
'''
return x * x - 2 * x + 1
def complicated_func(x):
'''
測試函數2
:param x:
:return:
'''
return x * x * x + 5 * math.sin(2 * x)
def parabolic_search(f, a, b, epsilon=1e-1):
'''
拋物線法,迭代函數
:param f: 目標函數
:param a: 起始點
:param b: 終止點
:param epsilon: 閾值
:return:
'''
h = (b - a) / 2
s0 = a
s1 = a + h
s2 = b
f0 = f(s0)
f1 = f(s1)
f2 = f(s2)
h_mean = (4 * f1 - 3 * f0 - f2) / (2 * (2 * f1 - f0 - f2)) * h
s_mean = s0 + h_mean
f_mean = f(s_mean)
# 調試
k = 0
while s2 - s0 > epsilon:
h = (s2 - s0) / 2
h_mean = (4 * f1 - 3 * f0 - f2) / (2 * (2 * f1 - f0 - f2)) * h
s_mean = s0 + h_mean
f_mean = f(s_mean)
if f1 <= f_mean:
if s1 < s_mean:
s2 = s_mean
f2 = f_mean
# 重新計算一次,書上並沒有寫,所以導致一直循環
s1 = (s2 + s0)/2
f1 = f(s1)
else:
s0 = s_mean
f0 = f_mean
s1 = (s2 + s0)/2
f1 = f(s1)
else:
if s1 > s_mean:
s2 = s1
s1 = s_mean
f2 = f1
f1 = f_mean
else:
s0 = s1
s1 = s_mean
f0 = f1
f1 = f_mean
# print([k, (s2 - s0), f_mean, s_mean])
print(k)
k += 1
return s_mean, f_mean
if __name__ == '__main__':
x = np.linspace(1, 3, 200)
y = []
index = 0
for i in x:
y.append(complicated_func(x[index]))
index += 1
plt.plot(x, y)
plt.show()
result = parabolic_search(complicated_func, 1.0, 3.0)
print(result)
# x = np.linspace(0, 2, 200)
# plt.plot(x, phi(x))
# plt.show()
# result = parabolic_search(phi, 0, 2.0)
# print(result)
算法結果
極值點:
(1.802896968512279, 3.6216601353779527)
代碼詳見:https://github.com/finepix/py_workspace/tree/master/optimization_algorithm