一,八數碼問題簡介
- 編號爲1到8的8個正方形滑塊被擺成3行3列(有一個格子留空),可以每次把與空格相鄰(有公共邊)的滑塊移動到空格中,而它原來的位置就成了新的空格。給定局面,計算出從當前狀態移動到目標狀態的最少步數。如將八數碼從左到右、從上到下的數字列出來,空格用0表示(其實也可以用9表示),可以表示爲:
2 6 4 1 3 7 0 5 8 -> 8 1 5 7 3 6 4 0 2
在這裏我們的目標狀態是1 2 3 4 5 6 7 8 0
- 預備知識
康托展開,因爲八數碼中有八個數,再加上一個用0表示的空格,可以看成是0到8的全排列,總共有9!=362880個狀態,幸運的是康托展開可以將一個1到n的排列對應到整數0到n!-1,其實就是這個排列在所有排列(按字典序)中出現的位置。
如2 4 1 3,這個排列在1到4的所有全排列的位置可以這樣計算,第一個位置上的數比2小的全排列個數是1*(4-1)!=6;第一個位置是2,第二個位置上的數比4小的全排列個數是(3-1)*(4-2)!=4,這裏3-1的原因是比4小的個數有3個,但是2已經出現在4前面了,就不必再計算了,這樣就可以計算出2 4 1 3對應的數字是6+4+0+0=10。其實就是看這個數字後邊有多少個比自己小的數字,乘上後邊還有多少個位置的全排列值。代碼:
/*
 * Cantor expansion of perm[0..len-1].
 *
 * For every position the number of later elements that are smaller is
 * multiplied by the factorial of the number of positions that follow, and
 * the products are summed.  fac[] is the global table of factorials.
 */
int getcode(int perm[], int len)
{
    int rank = 0;
    for (int pos = 0; pos < len; ++pos) {
        int smaller = 0;
        // Walk the suffix behind pos and count the smaller elements.
        for (int k = len - 1; k > pos; --k) {
            smaller += (perm[k] < perm[pos]) ? 1 : 0;
        }
        rank += fac[len - 1 - pos] * smaller;
    }
    return rank;
}
還有另外一個小優化,就是我們忽略0之後(如果空格用9表示,則忽略9),每一步移動之後的全排列的逆序數的奇偶性是不變的,逆序數就是當前數字之後小於這個數的個數,全部加起來。因此,如果初始狀態與目標狀態的逆序數奇偶性不同,就一定無解,可以直接判斷而不用搜索。
二,一般解法
就是使用一般的bfs,對於每一個狀態[2 6 4 1 3 7 0 5 8],擴展其可以轉到的狀態,使用康托展開標識這個狀態以前有沒有出現過,用一個父節點數組存儲父節點,對於當前節點用一個方向數組存放從父節點到當前節點的移動方向,找到解之後,就按照存儲的父節點位置,得到整個路徑,最後輸出就行了。可以加上上邊提到的小優化。但是這個效率不高,在poj上可以通過,在hdu上會超時。代碼如下:
/*************************************************************************
> File Name: 1077.cpp
> Author: gwq
> Mail: [email protected]
> Created Time: 2015年08月12日 星期三 10時35分19秒
************************************************************************/
#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>
// A large sentinel value.
#define INF (INT_MAX / 10)
// Fill a whole array with a byte value (arr must be a real array, not a pointer).
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
// Container size as a signed int.
#define sz(a) ((int)(a).size())
using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;
const double esp = 1e-5;
// NOTE(review): N does not appear to be used by this program.
#define N 5
// Capacity of the state queue and of the visited table; it has to exceed
// 9! = 362880, the number of distinct Cantor codes of a 9-cell board.
#define M 400000
// head/tail: read and write positions of the array-backed BFS queue st[].
// orilen: number of cells parsed from the current input line.
int head, tail, orilen;
// Row and column deltas of the four blank moves, in the same order as the
// letters of mm ("urdl": up, right, down, left).
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
// BFS queue: st[i] is the 3x3 board stored in slot i, row by row.
int st[M][9];
// Solved board; 0 stands for the blank.
int goal[9] = {1, 2, 3, 4, 5, 6, 7, 8, 0};
// vis[code] is nonzero once the board with that Cantor code has been queued.
int vis[M];
char mm[] = "urdl";
// Raw text of the current input line.
char strtmp[100];
// Start board parsed from the input line ('x' is stored as 0).
int oritmp[9];
// Factorial table filled in by main.
int fac[20];
// fa[i]: queue slot of the parent of slot i (-1 for the start board).
int fa[M];
// direct[i]: move letter that led from fa[i] to slot i.
char direct[M];
/*
 * Cantor code of a 9-cell board: for each cell, the count of later cells
 * holding a smaller value is weighted by the factorial of the number of
 * cells that follow it.  Uses the global factorial table fac[].
 */
int getcode(int s[])
{
    int code = 0;
    for (int pos = 0; pos < 9; ++pos) {
        int smaller = 0;
        for (int k = 8; k > pos; --k) {
            smaller += (s[k] < s[pos]) ? 1 : 0;
        }
        code += smaller * fac[8 - pos];
    }
    return code;
}
int try_insert(int idx)
{
int code = getcode(st[idx]);
if (vis[code]) {
return 0;
} else {
return vis[code] = 1;
}
}
/*
 * Debug helper: writes the nine cells of t on one line, each followed by a
 * space, then a newline.
 */
void print(int t[])
{
    const int *const last = t + 9;
    for (const int *cell = t; cell != last; ++cell) {
        printf("%d ", *cell);
    }
    printf("\n");
}
int bfs(void)
{
clr(vis, 0);
head = 1;
tail = 2;
memcpy(st[head], oritmp, sizeof(oritmp));
clr(fa, -1);
fa[1] = -1;
direct[1] = '*';
vis[getcode(st[head])] = 1;
while (head < tail) {
//print(st[head]);
if (memcmp(st[head], goal, sizeof(goal)) == 0) {
return head;
}
int idx = 0;
for (int i = 0; i < 9; ++i) {
if (st[head][i] == 0) {
idx = i;
}
}
int x = idx / 3;
int y = idx % 3;
//printf("%d %d %d\n", idx, x, y);
//getchar();
for (int i = 0; i < 4; ++i) {
int nx = x + dx[i];
int ny = y + dy[i];
int nidx = nx * 3 + ny;
if (nx >= 0 && nx < 3 && ny >= 0 && ny < 3) {
memcpy(st[tail], st[head], sizeof(st[head]));
swap(st[tail][nidx], st[tail][idx]);
fa[tail] = head;
direct[tail] = mm[i];
if (try_insert(tail)) {
++tail;
}
}
}
++head;
}
return 0;
}
/*
 * Accepted on POJ 1077, but exceeds the time limit on HDU 1043.
 *
 * Reads one puzzle per line (digits for tiles, 'x' for the blank), runs the
 * BFS and prints either the move string or "unsolvable".  Lines that do not
 * contain exactly nine cells (for example a trailing blank line) are
 * skipped instead of re-solving whatever board was left in oritmp.
 */
int main(int argc, char *argv[])
{
    // Only fac[0..8] is ever read.  13! no longer fits in a 32-bit int, so
    // the table stops at 12! to avoid signed overflow.
    fac[0] = 1;
    for (int i = 1; i <= 12; ++i) {
        fac[i] = fac[i - 1] * i;
    }
    while (fgets(strtmp, sizeof(strtmp), stdin) != NULL) {
        orilen = 0;
        bool tooMany = false;
        for (const char *p = strtmp; *p != '\0'; ++p) {
            int cell;
            if (*p == 'x') {
                cell = 0;
            } else if (isdigit(static_cast<unsigned char>(*p))) {
                cell = *p - '0';
            } else {
                continue;
            }
            // oritmp has room for nine cells only.
            if (orilen == 9) {
                tooMany = true;
                break;
            }
            oritmp[orilen++] = cell;
        }
        if (tooMany || orilen != 9) {
            continue;
        }
        int p = bfs();
        if (p == 0) {
            printf("unsolvable\n");
            continue;
        }
        // Follow the parent links back to the start, then flip the letters
        // into start-to-goal order.
        string ans;
        while (fa[p] != -1) {
            ans.pb(direct[p]);
            p = fa[p];
        }
        reverse(ans.begin(), ans.end());
        printf("%s\n", ans.c_str());
    }
    return 0;
}
爲了優化算法,可以預先處理得到所有狀態到目標狀態的路徑,最後,對於某一個狀態,直接輸出路徑就行了。搜索的時候,從目標狀態開始搜索,同時記錄路徑。因爲poj上數據比較少,預處理所有狀態的開銷無法被多組詢問攤薄,用這種方法反而會超時。
代碼如下:
/*************************************************************************
> File Name: 1077pre.cpp
> Author: gwq
> Mail: [email protected]
> Created Time: 2015年08月12日 星期三 19時07分45秒
************************************************************************/
#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>
// A large sentinel value.
#define INF (INT_MAX / 10)
// Fill a whole array with a byte value (arr must be a real array, not a pointer).
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
// Container size as a signed int.
#define sz(a) ((int)(a).size())
using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;
const double esp = 1e-5;
// Capacity of the queue and of the per-code tables (path, vis).
#define N 600000
int orilen = 9;
// Number of cells parsed from the current input line.
int goalen;
// Solved board, from which the search starts; the blank is stored as 9 here.
int ori[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
// Board read from the input line ('x' is stored as 9).
int goal[9];
// Factorial table filled in by main.
int fac[20];
// path[code]: move letters recorded while reaching the board with that code
// from the solved board; main prints them back to front.
string path[N];
// vis[code] is nonzero once the board with that code has been reached.
int vis[N];
// Array-backed BFS queue and its read/write positions.
int st[N][9], head, tail;
// Row and column deltas of the four blank moves, in the same order as the
// letters of mm ("urdl": up, right, down, left).
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
char mm[] = "urdl";
// Raw text of the current input line.
char buf[100];
/*
 * Encodes a board whose cells hold the values 1..9 as an integer.
 *
 * For every cell the number of earlier cells with a smaller value is
 * counted and weighted by fac[value - 1].
 */
int getcode(int s[])
{
    int code = 0;
    for (int pos = 0; pos < 9; ++pos) {
        const int value = s[pos];
        int before = 0;
        for (int k = pos - 1; k >= 0; --k) {
            if (s[k] < value) {
                ++before;
            }
        }
        code += before * fac[value - 1];
    }
    return code;
}
/*
 * Precomputation: breadth-first search outward from the solved board (ori)
 * over every reachable board.
 *
 * For each board reached, path[code] receives the parent's string plus the
 * letter opposite to the move just made (mm[(i + 2) % 4] swaps u<->d and
 * r<->l), and vis[code] is set.  The code of the board being expanded is
 * computed once per board rather than once per direction.
 */
void bfs(void)
{
    head = 1;
    tail = 2;
    clr(vis, 0);
    memcpy(st[head], ori, sizeof(ori));
    int code = getcode(st[head]);
    vis[code] = 1;
    path[code] = "";
    for (; head < tail; ++head) {
        // Invariant for all four directions, so compute it up front.
        code = getcode(st[head]);

        // The blank is stored as 9 in this program.
        int idx = 0;
        for (int i = 0; i < 9; ++i) {
            if (st[head][i] == 9) {
                idx = i;
            }
        }
        const int x = idx / 3;
        const int y = idx % 3;
        for (int i = 0; i < 4; ++i) {
            const int nx = x + dx[i];
            const int ny = y + dy[i];
            if (nx < 0 || nx >= 3 || ny < 0 || ny >= 3) {
                continue;
            }
            const int nidx = nx * 3 + ny;
            // Build the successor in the next free slot; keep it only if new.
            memcpy(st[tail], st[head], sizeof(st[head]));
            swap(st[tail][nidx], st[tail][idx]);
            const int ncode = getcode(st[tail]);
            if (!vis[ncode]) {
                path[ncode] = path[code] + mm[(i + 2) % 4];
                vis[ncode] = 1;
                ++tail;
            }
        }
    }
}
/*
1. 從終點擴展,記錄所用的路徑
2. hdu1043可以過
*/
int main(int argc, char *argv[])
{
fac[0] = 1;
for (int i = 1; i < 20; ++i) {
fac[i] = fac[i - 1] * i;
}
bfs();
while (fgets(buf, 100, stdin) != NULL) {
int len = strlen(buf);
goalen = 0;
for (int i = 0; i < len; ++i) {
if (isdigit(buf[i])) {
goal[goalen++] = buf[i] - '0';
} else if (buf[i] == 'x') {
goal[goalen++] = 9;
}
}
int code = getcode(goal);
if (vis[code]) {
int len = path[code].size();
for (int i = len - 1; i >= 0; --i) {
printf("%c", path[code][i]);
}
printf("\n");
} else {
printf("unsolvable\n");
}
}
return 0;
}
三,a*算法
先來介紹一下A算法,在BFS算法中,若對每個狀態n都設定估價函數f(n)=g(n)+h(n),並且每次從Open表中選節點進行擴展 時,都選取f值最小的節點,則稱該搜索算法爲啓發式搜索算法,又稱A算法。
在估價函數f(n)中,g(n)是從起始狀態到當前狀態n的代價,h(n)是從當前狀態n到目標狀態的估計代價。
A算法中若對估價函數選取不當,則可能找不到解,或者找到的解不是最優解。因此,需要對估價函數做一些限制,使得算法確保找到最優解。A*算法即爲對估價函數做了特定限制,且確保找到最優解的A算法。
f*(n) = g*(n) + h*(n),其中,f*(n)是從初始節點S0出發,經過節點n到達目標節點的最小步數(真實值),g*(n)是從S0出發,到達n的最小步數(真實值),h*(n)是從n出發,到達目標節點的最少步數(真實值),而估價函數f(n)是f*(n)的估計值。
f(n)=g(n)+h(n),且滿足:g(n)是從S0到n的真實步數(未必是最優的),因此,g(n)>=0且g(n)>=g*(n),h(n)是從n到目標的估計步數,估計總是過於樂觀的,即h(n)<=h*(n),且h(n)相容,則A算法轉變成A*算法。
h(n)相容是指,如果對任意s1到s2滿足h(s1)<=h(s2)+c(s1,s2),其中c(s1,s2)是s1轉移到s2的步數,則稱h是相容的。h相容能確保隨着一步步往前走,f遞增,這樣A*能更高效的找到最優解。一般來說,在滿足h(n)<=h*(n)的前提下,h(n)的值越大越好。
一般用從當前節點到目標節點的直線距離或者曼哈頓距離作爲估值函數h,但也要具體問題具體分析。
下面是僞代碼(原文鏈接在這裏):
OPEN = priority queue containing START
CLOSED = empty set
while lowest rank in OPEN is not the GOAL:
current = remove lowest rank item from OPEN
add current to CLOSED
for neighbors of current:
cost = g(current) + movementcost(current, neighbor)
if neighbor in OPEN and cost less than g(neighbor):
remove neighbor from OPEN, because new path is better
if neighbor in CLOSED and cost less than g(neighbor): **
remove neighbor from CLOSED
if neighbor not in OPEN and neighbor not in CLOSED:
set g(neighbor) to cost
add neighbor to OPEN
set priority queue rank to g(neighbor) + h(neighbor)
set neighbor's parent to current
reconstruct reverse path from goal to start
by following parent pointers
但實際上我們平常寫的A*並不是這個樣子,而是和普通的bfs類似,將FIFO隊列換成優先隊列,其他的類似。
使用A*算法的代碼如下,估值函數用的是曼哈頓距離(注意:代碼中的優先隊列先比較h,h相同時再比較g,並沒有按f=g+h排序,因此更接近貪心最佳優先搜索,得到的路徑不保證是最短的):
/*************************************************************************
> File Name: 1043_astar.cpp
> Author: gwq
> Mail: [email protected]
> Created Time: 2015年08月13日 星期四 16時44分33秒
************************************************************************/
#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>
// A large sentinel value.
#define INF (INT_MAX / 10)
// Fill a whole array with a byte value (arr must be a real array, not a pointer).
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
// Container size as a signed int.
#define sz(a) ((int)(a).size())
using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;
const double esp = 1e-5;
// Size of the per-code tables; it has to exceed 9! = 362880, the number of
// distinct Cantor codes of a 9-cell board.
#define N 400000
// Start board parsed from the input line ('x' is stored as 0).
int ori[9];
// Number of cells parsed from the current input line.
int orilen;
// Index (0..8) of the blank in ori.
int oripos;
// Solved board; 0 stands for the blank.
int goal[9] = {1, 2, 3, 4, 5, 6, 7, 8, 0};
int goalpos = 8;
// Cantor code of goal, computed once in main.
int goalcode;
// vis[code] is nonzero once the board with that code has been queued.
int vis[N];
// Row and column deltas of the four blank moves, in the same order as the
// letters of mm ("urdl": up, right, down, left).
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
char mm[] = "urdl";
// Raw text of the current input line.
char str[100];
// Factorial table filled in by main.
int fac[20];
// pre[code]: code of the board this one was reached from (-1 for the start).
int pre[N];
// direct[code]: move letter that produced the board with that code.
char direct[N];
// Length of the current input line.
int len;
/*
 * Cantor code of a 9-cell board: for each cell, the count of later cells
 * holding a smaller value is weighted by the factorial of the number of
 * cells that follow it.  Uses the global factorial table fac[].
 */
int getcode(int s[])
{
    int code = 0;
    for (int pos = 0; pos < 9; ++pos) {
        int smaller = 0;
        for (int k = 8; k > pos; --k) {
            smaller += (s[k] < s[pos]) ? 1 : 0;
        }
        code += fac[8 - pos] * smaller;
    }
    return code;
}
/*
 * One search state.
 *   perm  board contents, row by row (0 is the blank)
 *   h     heuristic estimate for this board
 *   g     number of moves made so far
 *   f     g + h
 *   x, y  row and column of the blank
 *   pos   linear index of the blank
 *   st    Cantor code of perm
 */
struct Node {
    int perm[9];
    int h, g, x, y, st, pos, f;

    Node(int s[], int hh, int gg, int xx, int yy, int sst, int ppos)
        : h(hh), g(gg), x(xx), y(yy), st(sst), pos(ppos), f(gg + hh)
    {
        memcpy(perm, s, sizeof(perm));
    }

    // Leaves every field uninitialised; the caller fills them in.
    Node() {}

    // Debug dump: each row of three cells is preceded by a newline.
    void output(void)
    {
        for (int row = 0; row < 3; ++row) {
            printf("\n");
            for (int col = 0; col < 3; ++col) {
                printf("%d ", perm[row * 3 + col]);
            }
        }
    }

    // True when this state is the goal board.
    bool check(void)
    {
        return st == goalcode;
    }
};
/*
 * Heuristic: sum of Manhattan distances of all tiles from their home cells.
 * Tile v belongs at linear index v - 1; the blank (0) is not counted.
 */
int geth(int s[])
{
    int total = 0;
    for (int cell = 0; cell < 9; ++cell) {
        const int tile = s[cell];
        if (tile != 0) {
            const int home = tile - 1;
            total += std::abs(home / 3 - cell / 3)
                   + std::abs(home % 3 - cell % 3);
        }
    }
    return total;
}
bool operator <(Node u, Node v)
{
return u.h != v.h ? u.h > v.h : u.g > v.g;
}
/*
 * Parity of the number of inversions among the tiles of s, ignoring the
 * blank (0).  Returns 0 for an even count and 1 for an odd count.
 */
int check(int s[])
{
    int inversions = 0;
    for (int hi = 1; hi < 9; ++hi) {
        if (s[hi] == 0) {
            continue;
        }
        // Count the non-blank tiles in front of hi that are larger.
        for (int lo = 0; lo < hi; ++lo) {
            if (s[lo] != 0 && s[lo] > s[hi]) {
                ++inversions;
            }
        }
    }
    return inversions % 2;
}
void bfs(void)
{
priority_queue<Node> q;
clr(vis, 0);
clr(pre, -1);
clr(direct, '*');
int code = getcode(ori);
Node u = Node(ori, geth(ori), 0, oripos / 3, oripos % 3, code, oripos);
vis[code] = 1;
q.push(u);
while (!q.empty()) {
u = q.top();
q.pop();
//u.output();
//getchar();
if (u.check()) {
string path;
int p = u.st;
while (pre[p] != -1) {
path += direct[p];
p = pre[p];
}
reverse(path.begin(), path.end());
printf("%s\n", path.c_str());
return;
}
for (int i = 0; i < 4; ++i) {
int nx = u.x + dx[i];
int ny = u.y + dy[i];
int npos = nx * 3 + ny;
Node v;
if (nx >= 0 && nx < 3 && ny >= 0 && ny < 3) {
memcpy(v.perm, u.perm, sizeof(u.perm));
swap(v.perm[npos], v.perm[u.pos]);
int nh = geth(v.perm);
int ng = u.g + 1;
int ncode = getcode(v.perm);
v.h = nh;
v.g = ng;
v.f = v.h + v.g;
v.x = nx;
v.y = ny;
v.pos = npos;
v.st = ncode;
if (!vis[ncode] && !check(v.perm)) {
pre[ncode] = u.st;
direct[ncode] = mm[i];
q.push(v);
vis[ncode] = 1;
}
}
}
}
cout << "unsolvable" << endl;
}
/*
 * Reads one puzzle per line (digits for tiles, 'x' for the blank), rejects
 * boards with odd inversion parity and runs the search on the rest.
 *
 * Lines that do not contain exactly nine cells, or that contain no blank,
 * are skipped instead of searching with stale ori/oripos values.
 */
int main(int argc, char *argv[])
{
    // Only fac[0..8] is ever read.  13! no longer fits in a 32-bit int, so
    // the table stops at 12! to avoid signed overflow.
    fac[0] = 1;
    for (int i = 1; i <= 12; ++i) {
        fac[i] = fac[i - 1] * i;
    }
    goalcode = getcode(goal);
    while (fgets(str, sizeof(str), stdin) != NULL) {
        len = strlen(str);
        orilen = 0;
        int blank = -1;
        bool tooMany = false;
        for (int i = 0; i < len; ++i) {
            int cell;
            if (str[i] == 'x') {
                cell = 0;
            } else if (isdigit(static_cast<unsigned char>(str[i]))) {
                cell = str[i] - '0';
            } else {
                continue;
            }
            // ori has room for nine cells only.
            if (orilen == 9) {
                tooMany = true;
                break;
            }
            // A literal 0 is stored exactly like 'x', so it marks the blank
            // position as well.
            if (cell == 0) {
                blank = orilen;
            }
            ori[orilen++] = cell;
        }
        if (tooMany || orilen != 9 || blank < 0) {
            continue;
        }
        oripos = blank;
        if (check(ori)) {
            printf("unsolvable\n");
            continue;
        }
        bfs();
    }
    return 0;
}
四,DBFS雙向廣度優先搜索算法(參考pdf)
DBFS算法是對BFS算法的一種擴展。BFS算法以廣度優先的順序不斷擴展直到遇到目標節點。DBFS算法從起始節點和目標節點兩個方向以廣度優先的順序同時擴展,直到一個隊列中已經出現了另一個隊列中已經擴展了的節點,也就相當於兩個擴展方向有了交點,那麼可以認爲找到了一條路徑。
DBFS算法相對於BFS算法,因爲採用了雙向擴展的方法,搜索樹的寬度得到了明顯的減少,時間和空間開銷都有了明顯的降低。DBFS每次選擇節點數比較少的那邊進行擴展,並不是機械的進行擴展。
DBFS框架:
void dbfs()
{
1. 將起始節點放入隊列q0 ,將目標節點放入隊列q1;
2. 當兩個隊列都未空時,作如下循環:
1) 如果隊列q0裏的節點比q1中的少,則擴展隊列q0;
2) 否則擴展隊列q1
3. 如果隊列q0未空,不斷擴展q0直到爲空;
4. 如果隊列q1未空,不斷擴展q1直到爲空;
}
這道題目的代碼如下:
/*************************************************************************
> File Name: 1077dbfs.cpp
> Author: gwq
> Mail: [email protected]
> Created Time: 2015年08月12日 星期三 17時09分43秒
************************************************************************/
#include <cmath>
#include <ctime>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <algorithm>
// A large sentinel value.
#define INF (INT_MAX / 10)
// Fill a whole array with a byte value (arr must be a real array, not a pointer).
#define clr(arr, val) memset(arr, val, sizeof(arr))
#define pb push_back
// Container size as a signed int.
#define sz(a) ((int)(a).size())
using namespace std;
typedef set<int> si;
typedef vector<int> vi;
typedef map<int, int> mii;
typedef pair<int, int> pii;
typedef long long ll;
const double esp = 1e-5;
// Capacity of each queue and of each visited table; it has to exceed
// 9! = 362880, the number of distinct Cantor codes of a 9-cell board.
#define M 400000
// Number of cells parsed from the current input line.
int orilen;
// Start board parsed from the input line ('x' is stored as 0).
int oritmp[9];
// Row and column deltas of the four blank moves, in the same order as the
// letters of mm ("urdl": up, right, down, left).
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
char mm[] = "urdl";
// Solved board; 0 stands for the blank.
int goal[9] = {1, 2, 3, 4, 5, 6, 7, 8, 0};
// Small (20-byte) buffer for input text.
char strtmp[20];
// fac: factorial table filled in by main.
// vis[side][code]: nonzero once that side has reached the board with that code.
int fac[20], vis[2][M];
// Everything below is kept per search side: index 0 is the search from the
// start board, index 1 the search from the goal board.
// st[side][i]: board stored in queue slot i.
int st[2][M][9];
// Read and write positions of the two array-backed queues.
int head[2];
int tail[2];
// fa[side][i]: queue slot of the parent of slot i (-1 for the root).
int fa[2][M];
// direct[side][i]: move letter recorded for slot i.
char direct[2][M];
/*
 * Cantor code of a 9-cell board: for each cell, the count of later cells
 * holding a smaller value is weighted by the factorial of the number of
 * cells that follow it.  Uses the global factorial table fac[].
 */
int getcode(int s[])
{
    int code = 0;
    for (int pos = 0; pos < 9; ++pos) {
        int smaller = 0;
        for (int k = 8; k > pos; --k) {
            smaller += (s[k] < s[pos]) ? 1 : 0;
        }
        code += smaller * fac[8 - pos];
    }
    return code;
}
// With the blank (0) left out, a move never changes the parity of the
// permutation.  Returns that parity: 0 for an even number of inversions,
// 1 for an odd number.
int check(int s[])
{
    int inversions = 0;
    for (int hi = 1; hi < 9; ++hi) {
        if (s[hi] == 0) {
            continue;
        }
        // Count the non-blank tiles in front of hi that are larger.
        for (int lo = 0; lo < hi; ++lo) {
            if (s[lo] != 0 && s[lo] > s[hi]) {
                ++inversions;
            }
        }
    }
    return inversions % 2;
}
/*
 * Bidirectional BFS between the start board (oritmp, side 0) and the solved
 * board (goal, side 1).  Prints the move string when the two searches meet,
 * or "unsolvable" if one of the queues runs dry first.
 *
 * Each side has its own array-backed queue st[side][head..tail), parent
 * links fa[side][] and move letters direct[side][].  Side 1 records the
 * opposite letter of every move it makes, so its chain can be appended
 * as-is when read from the meeting board towards the goal.
 *
 * vis[side][code] holds the queue slot of the board with that Cantor code.
 * Slots start at 1, so 0 still means "not reached".  Keeping the slot lets
 * the two half-paths be joined without rescanning the queues.  The previous
 * scan began at the never-filled slot 0, whose all-zero board also encodes
 * to 0 and whose parent link is not -1, so a meeting board with code 0 made
 * the path walk loop forever.
 */
void dbfs(void)
{
    head[0] = head[1] = 1;
    tail[0] = tail[1] = 2;
    clr(vis, 0);
    memcpy(st[0][1], oritmp, sizeof(oritmp));
    memcpy(st[1][1], goal, sizeof(goal));
    fa[0][1] = -1;
    fa[1][1] = -1;
    direct[0][1] = '*';
    direct[1][1] = '*';
    // Both roots live in slot 1 of their queue.
    vis[0][getcode(st[0][1])] = 1;
    vis[1][getcode(st[1][1])] = 1;

    while (head[0] < tail[0] && head[1] < tail[1]) {
        // Expand the side with fewer pending boards; ties go to side 1.
        const int no = (tail[0] - head[0] < tail[1] - head[1]) ? 0 : 1;
        const int ono = 1 - no;
        int *cur = st[no][head[no]];
        const int code = getcode(cur);

        if (vis[ono][code]) {
            // Start -> meeting board: follow side 0's parents, then reverse.
            string ans;
            for (int p = vis[0][code]; fa[0][p] != -1; p = fa[0][p]) {
                ans += direct[0][p];
            }
            reverse(ans.begin(), ans.end());
            // Meeting board -> goal: side 1's letters are already inverted
            // and already in the right order.
            for (int p = vis[1][code]; fa[1][p] != -1; p = fa[1][p]) {
                ans += direct[1][p];
            }
            printf("%s\n", ans.c_str());
            return;
        }

        int idx = 0;
        for (int i = 0; i < 9; ++i) {
            if (cur[i] == 0) {
                idx = i;
                break;
            }
        }
        const int x = idx / 3;
        const int y = idx % 3;
        for (int i = 0; i < 4; ++i) {
            const int nx = x + dx[i];
            const int ny = y + dy[i];
            if (nx < 0 || nx >= 3 || ny < 0 || ny >= 3) {
                continue;
            }
            const int nidx = nx * 3 + ny;
            // Build the successor in the next free slot; keep it only if new.
            int *next = st[no][tail[no]];
            memcpy(next, cur, sizeof(st[no][head[no]]));
            swap(next[idx], next[nidx]);
            const int ncode = getcode(next);
            if (!vis[no][ncode]) {
                vis[no][ncode] = tail[no];
                fa[no][tail[no]] = head[no];
                direct[no][tail[no]] = mm[no ? (i + 2) % 4 : i];
                ++tail[no];
            }
        }
        ++head[no];
    }
    printf("unsolvable\n");
}
/*
 * Reads one puzzle per line (digits for tiles, 'x' for the blank), rejects
 * boards with odd inversion parity and runs the bidirectional search on the
 * rest.
 *
 * A local 100-byte line buffer is used: with a 20-byte buffer fgets would
 * cut any line longer than 19 characters (for example one with doubled
 * spaces) into several partial reads.  Lines that do not contain exactly
 * nine cells are skipped instead of searching with stale oritmp contents.
 */
int main(int argc, char *argv[])
{
    // Only fac[0..8] is ever read.  13! no longer fits in a 32-bit int, so
    // the table stops at 12! to avoid signed overflow.
    fac[0] = 1;
    for (int i = 1; i <= 12; ++i) {
        fac[i] = fac[i - 1] * i;
    }
    char line[100];
    while (fgets(line, sizeof(line), stdin) != NULL) {
        orilen = 0;
        bool tooMany = false;
        for (const char *p = line; *p != '\0'; ++p) {
            int cell;
            if (*p == 'x') {
                cell = 0;
            } else if (isdigit(static_cast<unsigned char>(*p))) {
                cell = *p - '0';
            } else {
                continue;
            }
            // oritmp has room for nine cells only.
            if (orilen == 9) {
                tooMany = true;
                break;
            }
            oritmp[orilen++] = cell;
        }
        if (tooMany || orilen != 9) {
            continue;
        }
        if (check(oritmp)) {
            printf("unsolvable\n");
            continue;
        }
        dbfs();
    }
    return 0;
}