RMQ(Range Minimum/Maximum Query),即區間最值查詢,是指這樣一個問題:對於長度爲n的數列A,回答若干次詢問RMQ(i,j),返回數列A中下標在區間[i,j]中的最小/大值。
本文介紹一種比較高效的ST算法解決這個問題。ST(Sparse Table)算法可以在O(nlogn)時間內進行預處理,然後在O(1)時間內回答每個查詢。
1)預處理
設A[i]是要求區間最值的數列,F[i, j]表示從第i個數起連續2^j個數中的最大值。(DP的狀態)
例如:
A數列爲:3 2 4 5 6 8 1 2 9 7
F[1,0]表示第1個數起,長度爲2^0=1的最大值,其實就是3這個數。同理 F[1,1] = max(3,2) = 3, F[1,2]=max(3,2,4,5) = 5,F[1,3] = max(3,2,4,5,6,8,1,2) = 8;
並且我們可以容易的看出F[i,0]就等於A[i]。(DP的初始值)
我們把F[i,j]平均分成兩段(因爲F[i,j]一定是偶數個數字),從 i 到i + 2 ^ (j - 1) - 1爲一段,i + 2 ^ (j - 1)到i + 2 ^ j - 1爲一段(長度都爲2 ^ (j - 1))。於是我們得到了狀態轉移方程F[i, j]=max(F[i,j-1], F[i + 2^(j-1),j-1])。
2)查詢
假如我們需要查詢的區間爲(i,j),那麼我們需要找到覆蓋這個閉區間(左邊界取i,右邊界取j)的最小冪(可以重複,比如查詢1,2,3,4,5,我們可以查詢1234和2345)。
因爲這個區間的長度爲j - i + 1,所以我們可以取k=log2( j - i + 1),則有:RMQ(i, j)=max{F[i , k], F[ j - 2 ^ k + 1, k]}。
舉例說明,要求區間[1,5]的最大值,k = log2(5 - 1 + 1)= 2,即求max(F[1, 2],F[5 - 2 ^ 2 + 1, 2])=max(F[1, 2],F[2, 2]);
void ST(int n) {
for (int i = 1; i <= n; i++)
dp[i][0] = A[i];
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
dp[i][j] = max(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
}
}
}
int RMQ(int l, int r) {
int k = 0;
while ((1 << (k + 1)) <= r - l + 1) k++;
return max(dp[l][k], dp[r - (1 << k) + 1][k]);
}
LCA(樹上最近公共祖先)
用上面的算法的話時間複雜度是 O(nlogn) 預處理,O(1) 在線查詢。
首先引入dfs序。
例如,上圖這棵樹的一個dfs序爲 8,5,9,5,8,4,6,15,6,7,6,4,10,11,10,16,3,16,12,16,10,2,10,4,8,1,14,1,13,1,8
9-6的路徑上所有的點在上面的數組中可以找到一個連續數列(9 5 8 4 6),其中出現的深度最小的節點就是他們的LCA。
a:從樹的根開始,將樹看成一個無向圖進行深度優先遍歷,記錄下每次到達的頂點,第一個頂點爲樹根root,
每經過一條邊都記錄它的端點,每條邊都恰好經過兩次.用數組ver記錄結點。
b:記錄first數組和deep數組,first數組記錄在深度優先遍歷時結點第一次出現的位置。deep數組記錄結點的深度
void dfs(int u,int dep)
{
vis[u]=true;
ver[++tot]=u;
first[u]=tot;
deep[tot]=dep;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(!vis[v])
{
dfs(v,dep+1);
ver[++tot]=u;
deep[tot]=dep;
}
}
}
可以發現,我們通過dfs記錄結點後,當我們要查詢結點u,v時,我們可以在結點的數組中找到u結點第一次出現的位置first[u] 和v結點第一次出現的位置 first[v],而他們位置之間的結點便是u到v的DFS順序,雖然其中可能包含u或v的後代,但其中深度最小的還是u和v的最近公共祖先。因此可以用ST表記錄與
結點數組相對應的深度序列的區間最小值下標,將lca轉化爲RMQ問題。
void ST(int n)
{
for(int i=1;i<=n;i++)
dp[i][0]=i;
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
{
int a=dp[i][j-1];int b=dp[i+(1<<(j-1))][j-1]; //記錄其中結點深度最小的結點的位置
dp[i][j]=deep[a]<deep[b]?a:b;
}
}
尋找LCA(u,v)時,先尋找first[u],first[v],將[first[u],first[v]]間的最小值的deep找出,該值下標所對應的結點即爲LCA(u,v)。
即當first[u]>first[v]時,LCA(T,u,v)=RMQ(deep,R[v],R[u]),否則LCA(T,u,v) = RMQ(deep,R[u],R[v]).
int RMQ(int l,int r)
{
int k=0;
while(1<<(k+1)<=r-l+1)
k++;
int a=dp[l][k],b=dp[r-(1<<k)+1][k];
return deep[a]<deep[b]?a:b;
}
int LCA(int u,int v)
{
int x=first[u],y=first[v];
if(x>y)swap(x,y);
int res=RMQ(x,y);
return ver[res];
}
總結:
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int MAX=10009;
int T,n,a,b;
int head[MAX],cnt=0;
int tot=0;
int dp[MAX*2][25]; //ST表
int deep[MAX*2]; //記錄節點深度
int ver[MAX*2]; //記錄節點編號
int first[MAX]; //記錄點第一次出現的位置
bool vis[MAX];
bool isroot[MAX]; //判斷根節點的數組
struct Edge{
int to,next;
}edge[MAX*2];
inline void add(int u,int v)
{
edge[cnt].to=v;
edge[cnt].next=head[u];
head[u]=cnt++;
}
void dfs(int u,int dep)
{
vis[u]=true; //訪問過該節點
ver[++tot]=u; //將該節點記錄在ver中
first[u]=tot; //記錄結點u第一次出現的位置
deep[tot]=dep; //記錄深度
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(!vis[v])
{
dfs(v,dep+1);
ver[++tot]=u;
deep[tot]=dep;
}
}
}
void ST(int n)
{
for(int i=1;i<=n;i++) //初始化
dp[i][0]=i;
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
{
int a=dp[i][j-1];int b=dp[i+(1<<(j-1))][j-1]; //記錄其中結點序列深度最小的結點的編號
dp[i][j]=deep[a]<deep[b]?a:b;
}
}
int RMQ(int l,int r)
{
int k=0;
while(1<<(k+1)<=r-l+1) //求區間長度以二爲底的對數
k++;
int a=dp[l][k],b=dp[r-(1<<k)+1][k];
return deep[a]<deep[b]?a:b;
}
int LCA(int u,int v)
{
int x=first[u],y=first[v];
if(x>y)swap(x,y);
int res=RMQ(x,y);
return ver[res];
}
void init()
{
memset(head,-1,sizeof(head)),cnt=0;tot=0;
memset(isroot,true,sizeof(isroot));
memset(vis,false,sizeof(vis));
memset(dp,0,sizeof(dp));
memset(deep,0,sizeof(deep));
memset(first,0,sizeof(first));
memset(ver,0,sizeof(ver));
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
init();
for(int i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
isroot[b]=false;
add(a,b);
add(b,a);
}
int root;
for(int i=1;i<=n;i++)
{
if(isroot[i])
{
root=i;break;
}
}
dfs(root,1);
ST(2*n-1);
scanf("%d%d",&a,&b);
printf("%d\n",LCA(a,b));
}
return 0;
}