精確線搜索-拋物線法python實現

拋物線法

拋物線法也叫做二次插值法,基本思想爲:在搜索的區間中不斷的使用二次多相似去近似目標函數,並且逐步用插值多項式去逼近線搜索問題。具體推導見《最優化方法及其matlab程序設計》P18。

代碼如下

import numpy as np
import matplotlib.pyplot as plt
import math


def phi(x):
    '''
        測試函數1
    :param x:
    :return:
    '''
    return x * x - 2 * x + 1


def complicated_func(x):
    '''
        測試函數2
    :param x:
    :return:
    '''
    return x * x * x + 5 * math.sin(2 * x)


def parabolic_search(f, a, b, epsilon=1e-1):
    '''
        拋物線法,迭代函數
    :param f: 目標函數
    :param a:   起始點
    :param b:   終止點
    :param epsilon: 閾值
    :return:
    '''
    h = (b - a) / 2
    s0 = a
    s1 = a + h
    s2 = b
    f0 = f(s0)
    f1 = f(s1)
    f2 = f(s2)
    h_mean = (4 * f1 - 3 * f0 - f2) / (2 * (2 * f1 - f0 - f2)) * h
    s_mean = s0 + h_mean
    f_mean = f(s_mean)
    # 調試
    k = 0
    while s2 - s0 > epsilon:
        h = (s2 - s0) / 2
        h_mean = (4 * f1 - 3 * f0 - f2) / (2 * (2 * f1 - f0 - f2)) * h
        s_mean = s0 + h_mean
        f_mean = f(s_mean)
        if f1 <= f_mean:
            if s1 < s_mean:
                s2 = s_mean
                f2 = f_mean
                # 重新計算一次,書上並沒有寫,所以導致一直循環
                s1 = (s2 + s0)/2
                f1 = f(s1)
            else:
                s0 = s_mean
                f0 = f_mean
                s1 = (s2 + s0)/2
                f1 = f(s1)
        else:
            if s1 > s_mean:
                s2 = s1
                s1 = s_mean
                f2 = f1
                f1 = f_mean
            else:
                s0 = s1
                s1 = s_mean
                f0 = f1
                f1 = f_mean
        # print([k, (s2 - s0), f_mean, s_mean])
        print(k)
        k += 1
    return s_mean, f_mean


if __name__ == '__main__':
    x = np.linspace(1, 3, 200)
    y = []
    index = 0
    for i in x:
        y.append(complicated_func(x[index]))
        index += 1
    plt.plot(x, y)
    plt.show()

    result = parabolic_search(complicated_func, 1.0, 3.0)
    print(result)

    # x = np.linspace(0, 2, 200)
    # plt.plot(x, phi(x))
    # plt.show()
    # result = parabolic_search(phi, 0, 2.0)
    # print(result)


算法結果

在這裏插入圖片描述

極值點:

(1.802896968512279, 3.6216601353779527)

代碼詳見:https://github.com/finepix/py_workspace/tree/master/optimization_algorithm

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章