算法描述
根據上一個博客介紹的dfs序以及歐拉序能夠把樹上的點轉爲線性的區間點,從而可以用區間的數據結構去維護。根據歐拉序的定義,我們會發現樹上任意兩點的第一次出現位置之間必然夾帶着lca的點,至於爲什麼可以畫圖理解一下,因爲我們生成這個歐拉序時每次回溯就加一個點,而任意兩點之間的搜索樹一定是從lca開始往下搜,然後回溯再轉而去搜另外一個點,所以lca就生成再兩點的時間戳之間了。
於是我們維護完歐拉序後我們可以得到序列中深度最小的那個點必然是lca,這兩點之間不會再夾帶深度更小的點了,原因和上述藍字一致。至此,整個問題從樹上求LCA轉爲求區間序列深度最小的點,即RMQ問題,對於這個算法有個預處理,查詢的高效算法:ST表(基於動態規劃和倍增思想)。這份博客有關於ST表求RMQ的講解:戳這裏。
不過我們的狀態要重新設計一下,設表示起點爲i,跳步長的深度最小的點,這樣設計的原因是我們要維護的是深度最小值,但要求的是最小值的那個點,不這樣子還要多一次哈希,感覺沒啥必要,直接找這個點,狀態轉移的時候哈希到深度就可以了。(其實就是少了n長度的空間)。其他的和st表的基本操作一致。
實現
#include <bits/stdc++.h>
using namespace std;
const int maxnn = (int)5e5+5;
const int maxnm = (int)1e6+5;
/**
* 利用歐拉序中兩點的lca會包含在兩個點的in之中的性質
* 查詢lca(u, v)相當於查min(in[u], in[v])
* 區間最值查詢可以用線段樹 理想複雜度是log和倍增的複雜度一樣 但常數大
* 這裏用st表化爲常數級區間最值查詢
* 這是個在線算法
*/
int _to[maxnm], _next[maxnm], head[maxnn], cnt;
int Log[maxnm], Mi[21]; //注意這個log的大小,避免越界,例如洛谷越界有時候不會報re而是wa
int in[maxnn], seq[maxnm], deep[maxnn], id;
int st[maxnm][21]; //生成2*n-1的歐拉序起點爲i,步長爲2^j的序列深度最小的點
int n, m, s;
void edge_add(int u, int v) {
_to[cnt] = v;
_next[cnt] = head[u];
head[u] = cnt++;
}
void init() {
memset(head, -1, sizeof(head));
memset(deep, 0, sizeof(deep));
cnt = id = 0;
int x, y;
for (int i = 1; i < n; i++) {
scanf("%d%d", &x, &y);
edge_add(x, y);
edge_add(y, x);
}
}
void dfs(int cur, int parent) {
seq[++id] = cur;
in[cur] = id;
deep[cur]=deep[parent]+1;
for (int i = head[cur]; ~i; i=_next[i]) {
int v = _to[i];
if (v == parent) continue;
dfs(v, cur);
seq[++id] = cur;
}
}
void rmq_init() {
Log[0] = -1;
for (int i = 1; i <= id; i++)
Log[i] = Log[i>>1] + 1; //另外一種遞推式:Log[i] = Log[i-1]+(1<<Log[i-1]==i) log(i)+1
Mi[0] = 1;
for (int i = 1; i <= 20; i++)
Mi[i] = Mi[i-1]<<1;
for (int i = 1; i <= id; i++)
st[i][0] = seq[i];
for (int j = 1; j <= Log[id]; j++) {
for (int i = 1; i+Mi[j] <= id+1; i++) {
//st[i][j] = min(st[i][j-1], st[i+Mi[j-1]][j-1])
if (deep[st[i][j-1]] <= deep[st[i+Mi[j-1]][j-1]]) {
st[i][j] = st[i][j-1];
} else {
st[i][j] = st[i+Mi[j-1]][j-1];
}
}
}
}
int query(int u, int v) {
if (u > v) swap(u, v);
int len = v - u + 1;
int k = Log[len];
int dc1 = st[u][k], dc2 = st[v-Mi[k]+1][k];
return deep[dc1] > deep[dc2] ? dc2 : dc1;
}
int main() {
scanf("%d%d%d", &n, &m, &s);
init();
dfs(s, 0);
rmq_init();
int x, y;
for (int i = 0; i < m; i++) {
scanf("%d%d", &x, &y);
printf("%d\n", query(in[x], in[y]));
}
return 0;
}