題意:給定一棵樹,查詢時給定兩個點,求出兩個點的距離。
暴力做肯定超時的。我的做法是採用lca(最近公共祖先)的離線算法,即tarjan算法(據說Tarjan提出了很多算法,可能還有很多tarjan算法),算法裏用到了並查集。在輸入完所有查詢之後,在求出答案。tarjan算法的做法是:一開始vis數組初始化爲0,從樹根開始遞歸往下對點進行染色,剛到一個點的時候將vis取爲-1,在繼續遞歸;遍歷完子節點返回之後vis變爲1。在vis變爲1之前,檢索一下當前節點的所有查詢,設查詢中的另外一個節點爲To,如果vis[To]==0,就continue,因爲To還沒有處理,不知道它的信息;如果vis[To]==-1,說明To被訪問了一次,但是還沒有返回到,這意味着To是當前節點的祖先,因此To就是當前節點的最近公共祖先;如果vis[To]==1,說明To已經處理完了,這時候並查集就派上用場了。在遞歸時,當一個節點處理完返回到父親那裏時,就把父親變成其所在集合的代表元素。在剛纔討論到vis[To]==1的情況中,可以知道find(To)(即To所在集合的代表元素)就是To和當前節點的最近公共祖先了(這個可以畫圖演算一下)。在這道題中,我們一開始可以用一個簡單的遞歸算出每個點到根節點的距離dis[i]。那麼對於一個查詢的兩個點fir和sec,它們的距離就是dis[fir]-dis[lca]+dis[sec]-dis[lca],lca是fir和sec的最近公共祖先。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<set>
#include<climits>
#include<queue>
#include<vector>
#include<map>
using namespace std;
struct node
{
int to,id;
node(int t,int i)
{
to=t;
id=i;
}
node(){}
};
const int maxn=50005;
vector<node>vec[maxn];
vector<pair<int,int>>query;
int father[maxn],fir[maxn<<1],nxt[maxn<<1],vv[maxn<<1],val[maxn<<1],dis[maxn],ans[75005],e;
int vis[maxn];//0 means it's white,-1 means it's grey, 1 means it's black
int findn(int n)
{
if(n!=father[n]) father[n]=findn(father[n]);
return father[n];
}
void add(int a,int b,int c,int i)
{
vv[e]=b;
val[e]=c;
nxt[e]=fir[a];
fir[a]=e++;
}
void get_height(int sroot,int dist)
{
vis[sroot]=1;
dis[sroot]=dist;
for(int i=fir[sroot];i!=-1;i=nxt[i])
{
int v=vv[i];
if(!vis[v])
{
get_height(v,dist+val[i]);
}
}
}
void dfs(int cur,int fa)
{
vis[cur]=-1;
for(int i=fir[cur];i!=-1;i=nxt[i])
{
int v=vv[i];
if(!vis[v])
{
dfs(v,cur);
father[v]=cur;
}
}
int size=vec[cur].size();
for(int i=0;i<size;i++)
{
node nxt=vec[cur][i];
if(!vis[nxt.to]) continue;
if(-1==vis[nxt.to])
{
ans[nxt.id]=nxt.to;
}
else if(1==vis[nxt.to])
{
ans[nxt.id]=findn(nxt.to);
}
}
vis[cur]=1;
}
int main()
{
#pragma comment(linker, "/STACK:102400000,102400000")//此代碼需要擴棧,可能在遞歸時耗的內存有點大
int n;
while(scanf("%d",&n)!=EOF)
{
for(int i=0;i<=n;i++)
{
father[i]=i;
fir[i]=-1;
vis[i]=0;
vec[i].clear();
}
e=0;//important
int a,b,c;
for(int i=0;i<n-1;i++)
{
scanf("%d%d%d",&a,&b,&c);
add(a,b,c,i);
add(b,a,c,i);
}
get_height(0,0);
int q;
scanf("%d",&q);
for(int i=0;i<q;i++)
{
scanf("%d%d",&a,&b);
vec[a].push_back(node(b,i));
vec[b].push_back(node(a,i));
query.push_back(make_pair<int,int>(a,b));
}
for(int i=0;i<=n;i++) vis[i]=0;
dfs(0,0);
int size=query.size();
for(int i=0;i<size;i++)
{
int fir=query[i].first;
int sec=query[i].second;
int lca=ans[i];
int distance=abs(dis[lca]-dis[fir])+abs(dis[lca]-dis[sec]);
printf("%d\n",distance);
}
}
}