【手敲算法】LCA (最近公共祖先) 從理解 到 裸敲

注:代碼經過簡單修改即可適用於 洛谷P3379

代碼洛谷題目範圍 爲 n = 5e5,以下代碼內部使用的 vector push_back 和 遍歷 vector 導致本題部分代碼 Time Limit

可以使用如下代碼替換 add建邊,和 循環

int cnt = -1, head[maxn];
struct node{
    int nxt, to;
};
node Edge[maxn*2];
void add(int u,int v){
	Edge[++cnt].nxt = head[u];
	Edge[cnt].to = v;
	head[u] = cnt;
}

// 主程序內需要初始化
	memset(head, -1, sizeof(head));
	for(int i=head[u];i!=-1;i=Edge[i].nxt){
		int v = Edge[i].to;
    }

 

爲什麼要學習LCA算法

LCA :兩個樹上節點,逐層向上尋找它們的父節點,第一次查找到的 "相同 (祖先) 節點"

查詢q次,每次給出兩個節點,向上遞歸查找,最壞情況下是 O(n* q) 的時間複雜度
要降低查詢的時間複雜度


共有三種算法:(在線:來一個查詢操作,執行一次query;離線:讀入全部查詢操作,一次性計算所有結果)

1. LCA tarjan算法                O(n +q)       離線算法
2. dfs + ST(RMQ) 算法        O(n * logn) 在線算法, RMQ預處理後可以做到  O(1) 查詢
3. 倍增算法                           O(n * logn) 在線算法,O(n * logn)處理數據,O(logn)查詢

  • 算法一:LCA tarjan算法

算法原理

dfs 過程:

1.  查詢當前  節點u,是否包含未遍歷  子節點v,包含繼續向下遞歸查找,回溯時 更新 節點v的 fa[v] = u
2.  遍歷完當前節點所有子樹,
     在查詢集合中,查找是否在  在包括當前節點u  的所有查詢中:
         如果這個查詢(u , v)的另一個節點v  被遍歷過,那麼當前這個查詢(u, v)的答案是   v節點所在並查集裏的根節點


如上圖 查詢 (8 , 11) 的 LCA,因爲2號節點 的 fa[2] 並未更新,還是初始值 2 ,
依次查到 5節點,10節點 11節點時,將它們放入綠框內,(fa[2] = 2不變,
之前所有綠框內的 節點,根據並查集,都會維護在2號節點下,他們所屬並查集的根節點 爲2
查到11節點,這個查詢(8 , 11)LCA因爲 vis[8] 被標記了成 "已遍歷",LCA = Find(8) = 2

可以將 理解爲,節點2  的所有子樹節點按遍歷順序放入一個棧中,棧名是 "棧2",
放入條件爲:某個節點子樹上的 所有節點 都被遍歷過,遍歷 每個  棧2  的子節點,如果它包含在query中
eg:(11,x),那麼判斷query內的 x 是否入棧,入棧則  棧名2 爲他們的 LCA。
所有的 "棧number" 逐級嵌套,按此原理計算所有答案。

代碼:n 個點,m此詢問,root爲根節點 (注:如果題目n範圍過大 5e5,需要換爲鏈式前向星建圖)

鏈式前向星比vector 的 push_back 更快

#include <bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define Clear( x , y ) memset( x , y , sizeof(x) );
#define Qcin() std::ios::sync_with_stdio(false);
using namespace std;
const int maxn = (int)1e5 + 5;
const int inf = 1e9 + 7;
typedef long long ll;
typedef pair <int, int> PII;

int n, m;
int fa[maxn];
int vis[maxn], ans[maxn];
vector <int> G[maxn];

struct node{
	int v, id;
};
vector <node> q[maxn];

void init(){
	for(int i=1 ; i <= n ; i++){
		fa[i] = i;
	}
}
int Find(int x){
	return fa[x] == x ? x : fa[x] = Find(fa[x]);
}

void dfs(int u, int ff){
	for(int v: G[u]){
		if(v == ff)	continue;
		dfs(v, u);
		fa[v] = u;
	}
	for(node U: q[u]){
		if(vis[U.v]) ans[U.id] = Find(U.v);
	}
	vis[u] = 1;
}


int main()
{
	Qcin();
	cin >> n >> m;
	init();
	int u, v;
	for(int i=1 ; i < n ; i++){
		cin >> u >> v;
		G[u].pb(v);
		G[v].pb(u);
	}
	for(int i=1 ; i <= m ; i++){
		cin >> u >> v;
		q[u].pb((node){v, i});
		q[v].pb((node){u, i});
	}
	int root = 1;
	dfs(root, 0);
	for(int i=1 ; i <= m ; i++){
		cout << ans[i] << endl;
	}
	return 0;
}
  • 算法二:dfs + ST(RMQ)

利用 dfs 求得 歐拉序,和RMQ結合 求出 LCA

解釋一下歐拉序 和 dfs序

如圖所示樹:

dfs序:第一次遍歷到 節點x,將其加入 數組末端 1 -> 2 -> 4 -> 3 -> 5 -> 6

歐拉序:遍歷完節點x,回溯時,它的父親節點 fa 再次添加進 數組末端
1 -> 2 -> 4 -> 2 -> 1 -> 3 -> 5 -> 3 -> 6 -> 3 -> 1

dfs 過程中記錄:

  1. 歐拉序 a [ ]
  2. 每個節點,在歐拉序中   第一次出現   的下標 id [ ]
  3. 每個節點在樹上的 深度

 在 使用 ST表(RMQ) 維護原本的 區間最小值,更改min的定義,minFunc() 將 兩段小區間內 dep 較小的 節點 ,更新爲整段大區間的 節點 (不是dep,是節點)

每次查詢 u,v 的 LCA ,直接查詢 id[u], id[v] 即兩節點在 歐拉序中 第一次出現的 位置,圍成的區間 ,此區間內的dep最小的 節點

如圖查找 4 , 6

發現最高的遍歷位置就是它們的 LCA

代碼

#include <bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define Clear( x , y ) memset( x , y , sizeof(x) );
#define Qcin() std::ios::sync_with_stdio(false);
using namespace std;
const int maxn = (int)5e5 + 5;
const int inf = 1e9 + 7;
typedef long long ll;
typedef pair <int, int> PII;
int n, m;
vector <int> G[maxn];
int a[maxn<<1], tot = 0, id[maxn], dep[maxn];
int dp[maxn<<1][30];

int minFunc(int u, int v){
	return dep[u] < dep[v] ? u : v;
}
void init_rmq(){
	for(int i=1 ; i <= tot ; i++){
		dp[i][0] = a[i];
	}
	for(int j=1 ; j < 20 ; j++){
		for(int i=1 ; i + (1<<j) - 1 <= tot ; i++){
			dp[i][j] = minFunc(dp[i][j-1], dp[i+(1<<(j-1))][j-1]);
		}
	}
}
int query(int l, int r){
	int k= (int)(log((double)(r-l+1))/log(2.0));
	int ans = minFunc(dp[l][k], dp[r-(1<<k)+1][k]);
}

void dfs(int u, int fa, int depth){
	a[++tot] = u;
	id[u] = tot;
	dep[u] = depth;
	for(int v: G[u]){
		if(v == fa)	continue;
		dfs(v, u, depth + 1);
		a[++tot] = u;
	}
}

int main()
{
	scanf(" %d %d",&n, &m);
	for(int i = 1 ; i < n ; i++){
		int u, v; scanf(" %d %d",&u, &v);
		G[u].pb(v);
		G[v].pb(u);
	}
	int root = 1;
	dfs(root, 0, 1);
	init_rmq();
	for(int i=1 ; i <= m ; i++){
		int u, v; scanf(" %d %d",&u, &v);
		u = id[u], v = id[v];
		if(u > v) swap(u, v);
		int ans = query(u, v);
		printf("%d\n",ans);
	}
	return 0;
}

算法三:倍增法

倍增法就是 通過 在樹上,每次向上跳 2 ^ k 的節點數,達到快速找到 LCA的目的

建圖是通過先初始化 p[ i ][ 0 ] 初始化 節點 i第 2 ^ 0 = 1 輩父親節點 ( fa[ i ] 即 節點 i 的 父親節點)

p [ i ] [ j ] 表示 節點 i   的 第 2 ^ j 輩 祖先節點

由此可見 倍增 可以 由 小輩祖先 (eg:2^2輩祖先) 推導 到 高輩祖先 (eg:2^3)

推導公式如下:(u是當前節點——如圖中16節點)

        p[ u ] [ i ] = p[  p[u][i-1]  ] [ i-1 ];

(p[ u ] [ i-1 ] 對應圖中 12節點)


按照 二進制 的增長速度,很快就達到非常大的距離,如此可以很快的 向上跳到LCA

查找LCA規則:

1. 先將兩 u,v中,深度更大 的 點,向上跳,尋找 兩點等深度的 祖先

2. 兩節點相同深度,查看是否是同一節點,true 則輸出,false 則 一起 繼續上跳(跳躍距離相同),查找各自祖先節點————直至它們的祖先節點相同

使深度相同 和 共同上跳

雙節點 上跳規則:

先跳 2 ^ j (初始j = max 20) 長度,發現祖先相同,則不跳

當 j-- 上跳 2 ^ (j = 19)  可能是如下

持續 j-- 當 指向祖先不同時,完成上跳動作

j-- 繼續上跳,直到 j = 0,完成跳躍。到上圖右邊 最終狀態。

注:中間可能有的 j 跳動了,有的沒有跳動,因爲假設到最終態 需要上跳的距離爲 12,則
12 = 2^3 + 2^2
j = 3 ~ 0 中只有 3 和 2 跳動了,1 和 0 ,沒有跳動。

最終 return   7號節點的父親節點  ,這就是它們的 LCA

代碼

#include <bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define Clear( x , y ) memset( x , y , sizeof(x) );
#define Qcin() std::ios::sync_with_stdio(false);
using namespace std;
const int maxn = (int)5e5 + 7;
const int inf = 1e9 + 7;
typedef long long ll;
typedef pair <int, int> PII;
int n, m;
vector <int> G[maxn];
int p[maxn][30], dep[maxn];

void dfs(int u, int fa, int depth){
	dep[u] = depth;
	for(int i = 1 ; (1<<i) <= dep[u] ; i++){
		p[u][i] = p[p[u][i-1]][i-1];
	}
	for(int v: G[u]){
		if(v == fa)	continue;
		p[v][0] = u;
		dfs(v, u, depth+1);
	}
}

int Lca(int a, int b){
	if(dep[a] < dep[b]) swap(a, b);

	for(int i = 20 ; i >= 0 ; i--){
		if(dep[a] - (1<<i) >= dep[b]){
			a = p[a][i];
		}
	}

	if(a == b) return a;

	for(int i = 20 ; i >= 0 ; i--){
		if(p[a][i] != p[b][i]){
			a = p[a][i];
			b = p[b][i];
		}
	}
	return p[a][0];
}

int main()
{
	scanf(" %d %d %d",&n, &m);
	for(int i = 1 ; i < n ; i++){
		int u, v; scanf(" %d %d",&u, &v);
		G[u].pb(v);
		G[v].pb(u);
	}
	int root = 1;
	dfs(root, 0, 1);
	for(int i = 1 ; i <= m ; i++){
		int u, v; scanf(" %d %d",&u, &v);
		int ans = Lca(u, v);
		printf("%d\n",ans);
	}
	return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章