A*算法,簡單實現八數碼問題

A*算法求解N數碼問題的設計與實現

Table of Contents

任務要求:

1. 關於A*算法:

2. 算法複雜度:

3. Solution:

4. CODE


任務要求:

  1. 以八數碼問題爲例實現A*算法的求解程序(編程語言不限),要求設計兩種以上的不同估價函數;
  2. 在求解八數碼問題的A*算法程序中,設置相同的初始狀態和目標狀態,針對不同的估價函數,求得問題的解,並比較它們對搜索算法性能的影響,包括擴展節點數、生成節點數等;
  3. 對於八8數碼問題,設置與上述2相同的初始狀態和目標狀態,用寬度優先搜索算法(即令估計代價h(n)=0的A*算法)求得問題的解,以及搜索過程中的擴展節點數、生成節點數;
  4. 上交源程序(要求有代碼註釋)。

1. 關於A*算法:

A*算法的核心代價函數的設計,loss = g + h,g爲當前狀態深度,h是關鍵,h代表當前狀態到達目標狀態的估計值,h必須滿足某些條件,而最重要的條件是h < r,r爲當前狀態到目標狀態的估計值。

以實例理解上述條件。例如在八數碼問題中,h可以被設計爲“各個數到目標狀態需要走的步數的和”,這顯然是小於真實需要步數和的,h也可以被設計爲“各個數和目標狀態不同的個數和”,這個條件顯然比第一個條件更加寬鬆,必然小於真實需要步數。以上即爲本次設計的兩個代價函數。

另一個例子是連續地圖中的尋路算法,h一般設計爲歐式距離,即直線距離,這顯然比真實需要走的路要短(可能有障礙,彎路等等)。

通過以上兩個例子,我們可以直觀地理解,爲什麼估計值h必須小於真實值,但要儘可能的大,接近真實值的下界。例如尋路算法中,你估算的距離越接近真實距離,那麼你啓發式找到可能的路徑就會越準確。理論上,可以嚴格證明滿足這些條件,必然可以找到最優解。

從另一個角度來審視A*算法,它可以視爲以代價爲步長的廣度優先算法,這一點要從代碼實現上才能感受到。每次都優先處理代價最小的狀態,如果觀察搜索樹,將會看到它在整個搜索數的節點之間無序跳動(選全局估計代價最小)。

再次,例如在尋路算法中,A*算法在實際運行中,類似於我們人類在找兩點之間的最快路徑,儘管我們無法確定中間要從何處繞開障礙,但是我們卻知道要忘目標靠。下面是A*算法運行實例,它在每個狀態,都會對目標計算一次歐式距離,以此約束選擇,就彷彿被歐式距離牽引到目標狀態一樣。

https://pic1.zhimg.com/80/v2-99dea949704bd4d7a52f397586954c04_720w.jpg

2. 算法複雜度:

  1. 廣度優先搜索,算法複雜度爲O(4^n),或者從另外的角度,八字碼的所有狀態爲n個數字的全排列,估計O(n!)。
  2. A*算法的好壞和代價函數h的設計密切相關,h必須儘可能小於並貼近真實代價,這樣朝目標貼近的方向越筆直(參考上圖),算法的一個上界和廣度搜索是一樣的,但實際上可以很快,我感覺在一些問題中可以接近線性複雜度。
  3. A*算法的空間消耗非常大,和實際複雜度類似,它需要保存每個狀態,同樣和h的設計相關。
  4. 過小的h的估計,會導致“這條路比較最短,我要深入下去,但其實這條路是錯的”這種現象,即會在錯誤的路上深入太深,h太小,那麼g就要越大才會發現走錯,越深的搜索,節點增長得越快,越接近指數級別。

3. Solution

 

init_state

[[3 1 7]

 [6 8 0]

 [4 2 5]]

target

[[7 2 1]

 [4 8 6]

 [3 0 5]]

----------A star with cost_function 1-----------

27 步之後能到達目標

生成節點數 5383

耗時:0.9075729846954346

----------A star with cost_function 2-----------

27 步之後能到達目標

生成節點數 181440

耗時:29.369495630264282

----------breadth search-----------

27 步之後能到達目標

生成節點數 9698

耗時:19.344292402267456

=========================================================

init_state

[[2 6 7]

 [5 0 3]

 [4 1 8]]

target

[[5 2 3]

 [4 1 0]

 [8 7 6]]

----------A star with cost_function 1-----------

16 步之後能到達目標

生成節點數 115

耗時:0.016954660415649414

----------A star with cost_function 2-----------

16 步之後能到達目標

生成節點數 132050

耗時:18.046760320663452

----------breadth search-----------

16 步之後能到達目標

生成節點數 3629

耗時:0.7639575004577637

===================================================================

init_state

[[2 8 7]

 [4 5 1]

 [0 6 3]]

target

[[5 7 4]

 [6 1 8]

 [3 2 0]]

----------A star with cost_function 1-----------

無解

耗時:0.0

----------A star with cost_function 2-----------

無解

耗時:0.0

----------breadth search-----------

無解

耗時:0.0009975433349609375

==================================================================

init_state

[[3 2 1]

 [6 7 5]

 [4 8 0]]

target

[[1 8 0]

 [5 2 6]

 [3 7 4]]

----------A star with cost_function 1-----------

25 步之後能到達目標

生成節點數 1805

耗時:0.26030421257019043

----------A star with cost_function 2-----------

27 步之後能到達目標

生成節點數 181440

耗時:31.541586875915527

----------breadth search-----------

25 步之後能到達目標

生成節點數 15633

耗時:16.068492650985718

4. CODE

import numpy as np
from queue import PriorityQueue, Queue
import time
np.random.seed(0)


# 計算g(n)
def g(cur):
    return cur.depth

# 計算h(n)
def h(cur):
    h = 0
    for x1 in range(3):
        for y1 in range(3):
            if cur.data[x1][y1] == 0:
                continue
            x2,y2 = n2posi[cur.data[x1][y1]]
            # print(x1,y1,np.abs(x1 - x2) + np.abs(y1 - y2))
            h += np.abs(x1 - x2) + np.abs(y1 - y2)
    return h

def h2(cur):
    return np.sum(cur.data == target_state.data)

# 計算cost(n)
def cost(cur):
    return g(cur) + h(cur)


def cost2(cur):
    return g(cur) + h2(cur)

# 測試是否有界
def is_solutable(init_, target):
    sum_a = 0
    sum_b = 0
    a = init_.data.reshape(-1)
    b = target.data.reshape(-1)
    for i in range(len(a)):
        sum_a += np.sum((a[i:] < a[i])&(a[i:]!=0))
    for i in range(len(b)):
        sum_b += np.sum((b[i:] < b[i])&(b[i:]!=0))
    
    # print(sum_a, sum_b)
    return (sum_a%2) == (sum_b%2)
    

# a star
def a_star(init_state, target_state, cost):
    if is_solutable(init_state,target_state) == False:
        print("無解")
        return None



    opens_ = PriorityQueue()  # 存放已觀察未訪問節點
    closes_ = PriorityQueue() # 存放已經訪問節點
    states = {}
    directions = [(-1,0), (0, -1), (1, 0), (0, 1)]

    opens_.put(init_state)

    cur = init_state

    while(True):
        # 獲得當前代價最小的進行訪問
        cur = opens_.get()
        closes_.put(cur)
        if cur == target_state:
            break
        for dx,dy in directions:
            x,y = cur.position()
            x_n,y_n = x+dx, y+dy
            # 越界跳出
            if x_n < 0 or x_n >= 3 or y_n < 0 or y_n >= 3:
                continue

            # 移動空白,創建新節點
            data = cur.data.copy()
            depth = cur.depth + 1
            data[x,y],data[x_n,y_n] = data[x_n,y_n],data[x,y]
            new_state = State(data, depth, x_n, y_n)
            new_state.root = cur
            new_state.cost_ = cost(new_state)
            # 該狀態是否已經訪問過,如果是則更新狀態
            if new_state in states:
                if new_state.cost_ < states[new_state].cost_:
                    # 更新
                    states[new_state].root = cur
                    states[new_state].cost_ = new_state.cost_
                    # 更新兩個最小堆
                    opens_.put(opens_.get())
                    closes_.put(closes_.get())
            else:
                states[new_state] = new_state
                opens_.put(new_state)
        
    r = cur
    num = 0
    # 輸出路徑
    while True:
        if cur is None:
            break
        num += 1
        # print("----^----")
        # print(cur)
        cur = cur.root
    print('{} 步之後能到達目標'.format(num))
    print('生成節點數 {}'.format(len(states)))
    return r

# 廣度搜索
def bf_search(init_state, target_state):
    if is_solutable(init_state, target_state) == False:
        print("無解")
        return
    q = Queue()                # 隊列
    states = {}            # 訪問狀態記錄
    directions = [(-1,0), (0, -1), (1, 0), (0, 1)]

    states[init_state] = init_state
    q.put(init_state)

    while q.empty() != True:
        cur = q.get()
        if cur == target_state:
            break
        for dx,dy in directions:
            x,y = cur.position()
            x_n,y_n = x+dx, y+dy
            # 越界跳出
            if x_n < 0 or x_n >= 3 or y_n < 0 or y_n >= 3:
                continue    
            # 移動空白,創建新節點
            data = cur.data.copy()
            depth = cur.depth + 1
            data[x,y],data[x_n,y_n] = data[x_n,y_n],data[x,y]
            new_state = State(data, depth, x_n, y_n)
            new_state.root = cur
            # 是否訪問過,如果是且更近更新狀態
            if new_state in states:
                if new_state.depth < states[new_state].depth:
                    states[new_state].depth = new_state.depth
                    states[new_state].root = cur

            else:
                states[new_state] = new_state
                q.put(new_state)
    r = cur
    num = 0
    # 輸出路徑
    while True:
        if cur is None:
            break
        num += 1
        # print("----^----")
        # print(cur)
        cur = cur.root
    print('{} 步之後能到達目標'.format(num))
    print('生成節點數 {}'.format(q.qsize()))

class State:

    def __init__(self, data, depth, blank_x, blank_y, cost_ = 1000):
        self.data = data
        # fac 用於狀態hash值的計算
        self.fac = np.array([10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000])
        self.depth = depth
        self.x = blank_x
        self.y = blank_y
        self.cost_ = cost_
        self.root = None

    # 返回空白位置
    def position(self):
        return self.x, self.y

    # 重載方法,便於在字典,最小堆等數據結構的使用自定義類
    def __hash__(self):
        return int(np.dot(self.data.reshape(-1), self.fac))
    
    def __eq__(self, other):
        return np.sum(self.data == other.data)==9
    
    def __str__(self):
        return str(self.data)

    def __lt__(self, other):
        return self.cost_ < other.cost_

# 目標狀態
target = np.random.choice(range(9), (3,3), replace=False).reshape(3, 3)
x,y = np.where(target == 0)
target_state = State(target, 0, x[0], y[0], cost_ = 0)
# 保存目標矩陣各個數字的position
n2posi = {}
for i in range(3):
    for j in range(3):
        n2posi[target[i][j]] = (i,j)

'''
a_star 代價函數一
'''
cost_func = cost


# 初始狀態
init_ = np.random.choice(range(9), (3,3), replace=False).reshape(3, 3)
x,y = np.where(init_ == 0)
init_state = State(init_, 0, x[0], y[0])
init_state.cost_ = cost_func(init_state)

print("init_state")
print(init_state)
print("target")
print(target)

print('----------A star with cost_function 1-----------')

start_time = time.time()
a_star(init_state,target_state, cost_func)
end_time = time.time()
print("耗時:{}".format(end_time - start_time))

'''
a_star 代價函數二
'''
print('----------A star with cost_function 2-----------')
cost_func = cost2

# 初始狀態
init_state = State(init_, 0, x[0], y[0])
init_state.cost_ = cost_func(init_state)

# print("init_state")
# print(init_state)
# print("target")
# print(target)


start_time = time.time()
a_star(init_state,target_state, cost_func)
end_time = time.time()
print("耗時:{}".format(end_time - start_time))

'''
廣搜
'''
print('----------breadth search-----------')

# 初始狀態
x,y = np.where(init_ == 0)
init_state = State(init_, 0, x[0], y[0])
init_state.cost_ = cost_func(init_state)

# print("init_state")
# print(init_state)
# print("target")
# print(target)

start_time = time.time()
bf_search(init_state, target_state)
end_time = time.time()
print("耗時:{}".format(end_time - start_time))

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章