一、理論篇
如下圖所示,是一個有7個頂點的圖,每條邊權值均爲1,試問從點0點6,有多少條最短路徑呢,分別是什麼?
我們可以直觀的看出來,一共有4條最短路徑,分別是
- 0->1->4->6
- 0->2->4->6
- 0->2->5->6
- 0->3->5->6
那麼問題來啦,我們要如何記錄下這些所有的最短路徑呢,以前在只需記錄一條最短路徑的情況下,我們只需要一個數組來記錄當前節點的前驅節點是什麼,最後再通過不斷的獲取當前節點的前驅節點來獲得路徑。在記錄所有的最短路徑時,我們需要建立一個二維數組,用於存儲當前節點可以由哪些節點得到,用最短路算法得到這樣一個二維數組後,我們便可以通過DFS得到所有路徑。
具體而言,針對上面的圖,我們建立一個二維數組vector<int> pre[7]
來存放0到每個節點的最短路可以由哪些前驅節點得到,那麼可知
- pre[0] = {}
- pre[1] = {0}
- pre[2] = {0}
- pre[3] = {0}
- pre[4] = {1, 2}
- pre[5] = {2, 3}
- pre[6] = {4, 5}
這樣,我們對pre進行dfs,就可以得到這樣一棵遞歸樹。這棵樹就是我們的解空間。
二、代碼篇
現在我們來一步一步寫程序實現這個求解過程,主要分爲兩步,第一步是用Dijistra得到pre數組,第二步是對pre進行DFS獲得所有路徑。首先,我們看看在Dijistra的最優子結構中,pre如何得到和更新。
if(dis[v] > dis[u] + G[u][v])
{
dis[v] = dis[u] + G[u][v];
pre[v].clear(); //這裏要記得clear
pre[v].push_back(u);
}
else if(dis[v] == dis[u] + G[u][v])
{
pre[v].push_back(u);
}
如果 dis[u] + G[u][v] < dis[v], 說明以u爲中介點可以使 dis[v]更優,此時需要令v的前驅結點爲u。並且即便原先 pre[v]中已經存放了若干結點,此處也應當先清空,然後再添加u。之後,如果 dis[u] + G[u][v] = dis[v], 說明以 u 爲中介點可以找到一條距離相同的路徑,因此v的前驅結點需要在原先的基礎上添加上 u 結點(而不必先清空pre[v])。完整的Dijistra如下,我用的是隊列優化的Dijistra,不優化的話直接將上述最優子結構加在普通的Dijistra上就行:
#include <iostream>
#include <vector>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 1e4+10;
vector<pair<int, int> > E[maxn];
vector<int> pre[maxn], temp;
int vis[maxn], dis[maxn];
void dijstra(int s)
{
fill(dis, dis+maxn, 0x3f3f3f3f);
dis[s] = 0;
priority_queue<pair<int, int> > q;
q.push(make_pair(0, s));
while(!q.empty())
{
int u = q.top().second;
q.pop();
if(vis[u]==1) continue;
vis[u] = 1;
for (int i = 0; i < E[u].size(); ++i)
{
int v = E[u][i].first, w = E[u][i].second;
if(dis[v]>dis[u]+w)
{
dis[v] = dis[u]+w;
pre[v].clear();
pre[v].push_back(u);
if(vis[v]==0) q.push(make_pair(-dis[v], v));
}
else if(dis[v]==dis[u]+w)
{
pre[v].push_back(u);
if(vis[v]==0) q.push(make_pair(-dis[v], v));
}
}
}
}
int main(int argc, char const *argv[])
{
int N, M, s, t; // 點數,邊數,起點,終點
cin >> N >> M >> s >> t;
while(M--)
{
int u, v, w;
cin >> u >> v >> w;
E[u].push_back(make_pair(v, w));
E[v].push_back(make_pair(u, w));
}
dijstra(s);
for (int i = 0; i < N; ++i)
{
printf("pre[%d]: ", i);
for (int j = 0; j < pre[i].size(); ++j)
{
printf("%d ", pre[i][j]);
}
printf("\n");
}
return 0;
}
/*
輸入
7 9 0 6
0 1 1
0 2 1
0 3 1
1 4 1
2 4 1
2 5 1
3 5 1
4 6 1
5 6 1
*/
第二部分是對pre進行dfs,獲得所有的最短路徑。我們從終點t開始dfs,當終點t等於起點s時輸出路徑,否則取出其前驅節點繼續dfs。在程序裏面我將註釋寫得很清楚啦
void dfs(int s, int t)
{
if(s==t) // 當到達起始節點,表明找到啦一條路徑,輸出
{
temp.push_back(t);
// 輸出路徑,若不想輸出可以用一個二維vector存儲每條路徑,注意是倒序
for (int i = temp.size()-1; i >= 0; i--){
cout << temp[i] << " ";
}
cout << endl;
temp.pop_back(); // 將剛加入的節點刪除
return;
}
temp.push_back(t); // 將當前訪問的節點加入臨時路徑temp後面
for (int i = 0; i < pre[t].size(); ++i){
dfs(s, pre[t][i]);
}
temp.pop_back(); // 遍歷完v點所有的前驅節點,將v刪除
}
總的求解程序如下
#include <iostream>
#include <vector>
#include <queue>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn = 1e4+10;
vector<pair<int, int> > E[maxn];
vector<int> pre[maxn], temp;
int vis[maxn], dis[maxn];
void dijstra(int s)
{
fill(dis, dis+maxn, 0x3f3f3f3f);
dis[s] = 0;
priority_queue<pair<int, int> > q;
q.push(make_pair(0, s));
while(!q.empty())
{
int u = q.top().second;
q.pop();
if(vis[u]==1) continue;
vis[u] = 1;
for (int i = 0; i < E[u].size(); ++i)
{
int v = E[u][i].first, w = E[u][i].second;
if(dis[v]>dis[u]+w)
{
dis[v] = dis[u]+w;
pre[v].clear();
pre[v].push_back(u);
if(vis[v]==0) q.push(make_pair(-dis[v], v));
}
else if(dis[v]==dis[u]+w)
{
pre[v].push_back(u);
if(vis[v]==0) q.push(make_pair(-dis[v], v));
}
}
}
}
void dfs(int s, int t)
{
if(s==t) // 當到達起始節點,表明找到啦一條路徑,輸出
{
temp.push_back(t);
// 輸出路徑,若不想輸出可以用一個二維vector存儲每條路徑
for (int i = temp.size()-1; i >= 0; i--){
cout << temp[i] << " ";
}
cout << endl;
temp.pop_back(); // 將剛加入的節點刪除
return;
}
temp.push_back(t); // 將當前訪問的節點加入臨時路徑temp後面
for (int i = 0; i < pre[t].size(); ++i){
dfs(s, pre[t][i]);
}
temp.pop_back(); // 遍歷完v點所有的前驅節點,將v刪除
}
int main(int argc, char const *argv[])
{
int N, M, s, t;
cin >> N >> M >> s >> t;
while(M--)
{
int u, v, w;
cin >> u >> v >> w;
E[u].push_back(make_pair(v, w));
E[v].push_back(make_pair(u, w));
}
dijstra(s);
dfs(s, t);
return 0;
}
/*
7 9 0 6
0 1 1
0 2 1
0 3 1
1 4 1
2 4 1
2 5 1
3 5 1
4 6 1
5 6 1
*/
運行結果如下: