A*算法求解N數碼問題的設計與實現
Table of Contents
任務要求:
- 以八數碼問題爲例實現A*算法的求解程序(編程語言不限),要求設計兩種以上的不同估價函數;
- 在求解八數碼問題的A*算法程序中,設置相同的初始狀態和目標狀態,針對不同的估價函數,求得問題的解,並比較它們對搜索算法性能的影響,包括擴展節點數、生成節點數等;
- 對於八8數碼問題,設置與上述2相同的初始狀態和目標狀態,用寬度優先搜索算法(即令估計代價h(n)=0的A*算法)求得問題的解,以及搜索過程中的擴展節點數、生成節點數;
- 上交源程序(要求有代碼註釋)。
1. 關於A*算法:
A*算法的核心代價函數的設計,loss = g + h,g爲當前狀態深度,h是關鍵,h代表當前狀態到達目標狀態的估計值,h必須滿足某些條件,而最重要的條件是h < r,r爲當前狀態到目標狀態的估計值。
以實例理解上述條件。例如在八數碼問題中,h可以被設計爲“各個數到目標狀態需要走的步數的和”,這顯然是小於真實需要步數和的,h也可以被設計爲“各個數和目標狀態不同的個數和”,這個條件顯然比第一個條件更加寬鬆,必然小於真實需要步數。以上即爲本次設計的兩個代價函數。
另一個例子是連續地圖中的尋路算法,h一般設計爲歐式距離,即直線距離,這顯然比真實需要走的路要短(可能有障礙,彎路等等)。
通過以上兩個例子,我們可以直觀地理解,爲什麼估計值h必須小於真實值,但要儘可能的大,接近真實值的下界。例如尋路算法中,你估算的距離越接近真實距離,那麼你啓發式找到可能的路徑就會越準確。理論上,可以嚴格證明滿足這些條件,必然可以找到最優解。
從另一個角度來審視A*算法,它可以視爲以代價爲步長的廣度優先算法,這一點要從代碼實現上才能感受到。每次都優先處理代價最小的狀態,如果觀察搜索樹,將會看到它在整個搜索數的節點之間無序跳動(選全局估計代價最小)。
再次,例如在尋路算法中,A*算法在實際運行中,類似於我們人類在找兩點之間的最快路徑,儘管我們無法確定中間要從何處繞開障礙,但是我們卻知道要忘目標靠。下面是A*算法運行實例,它在每個狀態,都會對目標計算一次歐式距離,以此約束選擇,就彷彿被歐式距離牽引到目標狀態一樣。
2. 算法複雜度:
- 廣度優先搜索,算法複雜度爲O(4^n),或者從另外的角度,八字碼的所有狀態爲n個數字的全排列,估計O(n!)。
- A*算法的好壞和代價函數h的設計密切相關,h必須儘可能小於並貼近真實代價,這樣朝目標貼近的方向越筆直(參考上圖),算法的一個上界和廣度搜索是一樣的,但實際上可以很快,我感覺在一些問題中可以接近線性複雜度。
- A*算法的空間消耗非常大,和實際複雜度類似,它需要保存每個狀態,同樣和h的設計相關。
- 過小的h的估計,會導致“這條路比較最短,我要深入下去,但其實這條路是錯的”這種現象,即會在錯誤的路上深入太深,h太小,那麼g就要越大才會發現走錯,越深的搜索,節點增長得越快,越接近指數級別。
3. Solution:
init_state
[[3 1 7]
[6 8 0]
[4 2 5]]
target
[[7 2 1]
[4 8 6]
[3 0 5]]
----------A star with cost_function 1-----------
27 步之後能到達目標
生成節點數 5383
耗時:0.9075729846954346
----------A star with cost_function 2-----------
27 步之後能到達目標
生成節點數 181440
耗時:29.369495630264282
----------breadth search-----------
27 步之後能到達目標
生成節點數 9698
耗時:19.344292402267456
=========================================================
init_state
[[2 6 7]
[5 0 3]
[4 1 8]]
target
[[5 2 3]
[4 1 0]
[8 7 6]]
----------A star with cost_function 1-----------
16 步之後能到達目標
生成節點數 115
耗時:0.016954660415649414
----------A star with cost_function 2-----------
16 步之後能到達目標
生成節點數 132050
耗時:18.046760320663452
----------breadth search-----------
16 步之後能到達目標
生成節點數 3629
耗時:0.7639575004577637
===================================================================
init_state
[[2 8 7]
[4 5 1]
[0 6 3]]
target
[[5 7 4]
[6 1 8]
[3 2 0]]
----------A star with cost_function 1-----------
無解
耗時:0.0
----------A star with cost_function 2-----------
無解
耗時:0.0
----------breadth search-----------
無解
耗時:0.0009975433349609375
==================================================================
init_state
[[3 2 1]
[6 7 5]
[4 8 0]]
target
[[1 8 0]
[5 2 6]
[3 7 4]]
----------A star with cost_function 1-----------
25 步之後能到達目標
生成節點數 1805
耗時:0.26030421257019043
----------A star with cost_function 2-----------
27 步之後能到達目標
生成節點數 181440
耗時:31.541586875915527
----------breadth search-----------
25 步之後能到達目標
生成節點數 15633
耗時:16.068492650985718
4. CODE
import numpy as np
from queue import PriorityQueue, Queue
import time
np.random.seed(0)
# 計算g(n)
def g(cur):
return cur.depth
# 計算h(n)
def h(cur):
h = 0
for x1 in range(3):
for y1 in range(3):
if cur.data[x1][y1] == 0:
continue
x2,y2 = n2posi[cur.data[x1][y1]]
# print(x1,y1,np.abs(x1 - x2) + np.abs(y1 - y2))
h += np.abs(x1 - x2) + np.abs(y1 - y2)
return h
def h2(cur):
return np.sum(cur.data == target_state.data)
# 計算cost(n)
def cost(cur):
return g(cur) + h(cur)
def cost2(cur):
return g(cur) + h2(cur)
# 測試是否有界
def is_solutable(init_, target):
sum_a = 0
sum_b = 0
a = init_.data.reshape(-1)
b = target.data.reshape(-1)
for i in range(len(a)):
sum_a += np.sum((a[i:] < a[i])&(a[i:]!=0))
for i in range(len(b)):
sum_b += np.sum((b[i:] < b[i])&(b[i:]!=0))
# print(sum_a, sum_b)
return (sum_a%2) == (sum_b%2)
# a star
def a_star(init_state, target_state, cost):
if is_solutable(init_state,target_state) == False:
print("無解")
return None
opens_ = PriorityQueue() # 存放已觀察未訪問節點
closes_ = PriorityQueue() # 存放已經訪問節點
states = {}
directions = [(-1,0), (0, -1), (1, 0), (0, 1)]
opens_.put(init_state)
cur = init_state
while(True):
# 獲得當前代價最小的進行訪問
cur = opens_.get()
closes_.put(cur)
if cur == target_state:
break
for dx,dy in directions:
x,y = cur.position()
x_n,y_n = x+dx, y+dy
# 越界跳出
if x_n < 0 or x_n >= 3 or y_n < 0 or y_n >= 3:
continue
# 移動空白,創建新節點
data = cur.data.copy()
depth = cur.depth + 1
data[x,y],data[x_n,y_n] = data[x_n,y_n],data[x,y]
new_state = State(data, depth, x_n, y_n)
new_state.root = cur
new_state.cost_ = cost(new_state)
# 該狀態是否已經訪問過,如果是則更新狀態
if new_state in states:
if new_state.cost_ < states[new_state].cost_:
# 更新
states[new_state].root = cur
states[new_state].cost_ = new_state.cost_
# 更新兩個最小堆
opens_.put(opens_.get())
closes_.put(closes_.get())
else:
states[new_state] = new_state
opens_.put(new_state)
r = cur
num = 0
# 輸出路徑
while True:
if cur is None:
break
num += 1
# print("----^----")
# print(cur)
cur = cur.root
print('{} 步之後能到達目標'.format(num))
print('生成節點數 {}'.format(len(states)))
return r
# 廣度搜索
def bf_search(init_state, target_state):
if is_solutable(init_state, target_state) == False:
print("無解")
return
q = Queue() # 隊列
states = {} # 訪問狀態記錄
directions = [(-1,0), (0, -1), (1, 0), (0, 1)]
states[init_state] = init_state
q.put(init_state)
while q.empty() != True:
cur = q.get()
if cur == target_state:
break
for dx,dy in directions:
x,y = cur.position()
x_n,y_n = x+dx, y+dy
# 越界跳出
if x_n < 0 or x_n >= 3 or y_n < 0 or y_n >= 3:
continue
# 移動空白,創建新節點
data = cur.data.copy()
depth = cur.depth + 1
data[x,y],data[x_n,y_n] = data[x_n,y_n],data[x,y]
new_state = State(data, depth, x_n, y_n)
new_state.root = cur
# 是否訪問過,如果是且更近更新狀態
if new_state in states:
if new_state.depth < states[new_state].depth:
states[new_state].depth = new_state.depth
states[new_state].root = cur
else:
states[new_state] = new_state
q.put(new_state)
r = cur
num = 0
# 輸出路徑
while True:
if cur is None:
break
num += 1
# print("----^----")
# print(cur)
cur = cur.root
print('{} 步之後能到達目標'.format(num))
print('生成節點數 {}'.format(q.qsize()))
class State:
def __init__(self, data, depth, blank_x, blank_y, cost_ = 1000):
self.data = data
# fac 用於狀態hash值的計算
self.fac = np.array([10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000])
self.depth = depth
self.x = blank_x
self.y = blank_y
self.cost_ = cost_
self.root = None
# 返回空白位置
def position(self):
return self.x, self.y
# 重載方法,便於在字典,最小堆等數據結構的使用自定義類
def __hash__(self):
return int(np.dot(self.data.reshape(-1), self.fac))
def __eq__(self, other):
return np.sum(self.data == other.data)==9
def __str__(self):
return str(self.data)
def __lt__(self, other):
return self.cost_ < other.cost_
# 目標狀態
target = np.random.choice(range(9), (3,3), replace=False).reshape(3, 3)
x,y = np.where(target == 0)
target_state = State(target, 0, x[0], y[0], cost_ = 0)
# 保存目標矩陣各個數字的position
n2posi = {}
for i in range(3):
for j in range(3):
n2posi[target[i][j]] = (i,j)
'''
a_star 代價函數一
'''
cost_func = cost
# 初始狀態
init_ = np.random.choice(range(9), (3,3), replace=False).reshape(3, 3)
x,y = np.where(init_ == 0)
init_state = State(init_, 0, x[0], y[0])
init_state.cost_ = cost_func(init_state)
print("init_state")
print(init_state)
print("target")
print(target)
print('----------A star with cost_function 1-----------')
start_time = time.time()
a_star(init_state,target_state, cost_func)
end_time = time.time()
print("耗時:{}".format(end_time - start_time))
'''
a_star 代價函數二
'''
print('----------A star with cost_function 2-----------')
cost_func = cost2
# 初始狀態
init_state = State(init_, 0, x[0], y[0])
init_state.cost_ = cost_func(init_state)
# print("init_state")
# print(init_state)
# print("target")
# print(target)
start_time = time.time()
a_star(init_state,target_state, cost_func)
end_time = time.time()
print("耗時:{}".format(end_time - start_time))
'''
廣搜
'''
print('----------breadth search-----------')
# 初始狀態
x,y = np.where(init_ == 0)
init_state = State(init_, 0, x[0], y[0])
init_state.cost_ = cost_func(init_state)
# print("init_state")
# print(init_state)
# print("target")
# print(target)
start_time = time.time()
bf_search(init_state, target_state)
end_time = time.time()
print("耗時:{}".format(end_time - start_time))