概述:
參考神犇yyb的博客
問題:如何做到\(O(nlogn)-O(1)\)複雜度求解\(k\)次祖先?
常規倍增是\(O(nlogn)-O(logn)\)的,重鏈剖分是\(O(nlogn)-O(logn)\)的,歐拉序st表能在\(O(nlogn)-O(1)\)複雜度內求兩點LCA,但並不能查出k次祖先是誰
長鏈剖分
方法和樹剖十分類似,代碼也幾乎相同,但我們每次不是挑子樹最大的兒子作爲重鏈,而是挑最大深度最大的兒子作爲重鏈
長鏈剖分有如下性質:
1.所有重鏈長度之和是\(O(n)\)級別
顯然每個點最多在一條重鏈內
2.如果x和k次祖先y不在同一重鏈內,那麼y點長鏈的鏈長(所在重鏈末尾節點到它的距離),一定大於等於k
如果小於k,那麼x-y這條鏈更長,與長鏈剖分的前提——挑最大深度的兒子相悖
繼續考慮怎麼利用性質
這個做法需要分類討論
在常規的重鏈剖分中,如果k級祖先和它在同一重鏈內(用深度判斷\(dep[top_{x}]-dep_[x]\ge k\)),我們可以在\(O(1)\)時間找到k級祖先(維護重鏈剖分序,同一重鏈上的點一定連續)
把這個想法拓展到長鏈剖分,我們去掉了x與k級祖先在同一重鏈上的情況
現在x和k級祖先不在同一重鏈上
有一個想法:我們找到x點的\(r\)級祖先,使得\(r>k/2\),我們能夠\(O(1)\)時間內求出x點的\(r\)級祖先z。然後考慮z的\(k-r\)級祖先,用上面的方法提到的check一下。如果不彳亍,說明z和y不在同一鏈內,且z的鏈頭T深度比y大
由長鏈剖分性質1可知,重鏈長度之和一定是\(O(n)\)級別,我們對於每個鏈頭,暴力處理出跳\([1,鏈長]\)長度時的祖先!!容易發現這個預處理複雜度是\(O(n)\)的
而我們找到的\(r>k/2\),利用上面預處理出的數組就可以\(O(1)\)找到y了
還剩一個問題,這個\(r\)級祖先怎麼選,才能\(O(1)\)找到呢?倍增就行了!我們令r一定是2的冪次,對於詢問k,我們取k的最高位\(highbit(k)\)即可
總結一下:每個點倍增預處理\(O(nlogn)\),長鏈剖分\(O(n)\),鏈頭的處理\(O(n)\),每次詢問\(O(1)\)
幾道題
給你一棵樹,對於每個點x,子樹內所有點到它都有一個距離,詢問出現次數最多的距離,輸出這個距離(點數相同時輸出最小的距離)
首先有個非常裸的重鏈剖分dsu做法,每次先處理輕子樹,然後把輕子樹桶信息清空,再進入重子樹,保留桶信息,再遍歷一遍輕子樹把信息丟進桶,然後處理答案。我們只有添加or清空操作,用桶維護,記錄最大值並更新。時間\(O(nlogn)\)
#include<bits/stdc++.h>
using namespace std;
#define r(x) read(x)
#define ll long long
#define it set<string>::iterator
const int N1=1e6+7;
template <typename _T> void read(_T &ret)
{
ret=0; _T fh=1; char c=getchar();
while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); }
while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); }
ret=ret*fh;
}
struct EDGE{
int to[N1*2],nxt[N1*2],head[N1],cte;
void ae(int u,int v)
{ cte++; to[cte]=v, nxt[cte]=head[u]; head[u]=cte; }
}e;
int n,ma;
int sz[N1],son[N1],dep[N1],bar[N1],ans[N1];
void dfs0(int u,int dad)
{
sz[u]=1;
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad) continue;
dep[v]=dep[u]+1;
dfs0(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]]) son[u]=v;
}
}
void push(int x,int w)
{
bar[dep[x]]+=w;
if(bar[dep[x]]>bar[ma]) ma=dep[x];
else if(bar[dep[x]]==bar[ma]&&dep[x]<ma) ma=dep[x];
}
void inbar(int u,int dad,int w)
{
push(u,w);
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad) continue;
inbar(v,u,w);
}
}
void dfs1(int u,int dad)
{
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j];
if(v==dad||v==son[u]) continue;
dfs1(v,u);
inbar(v,u,-1);
}
ma=0;
if(son[u]){
dfs1(son[u],u);
}
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j];
if(v==dad||v==son[u]) continue;
inbar(v,u,1);
}
push(u,1);
ans[u]=ma-dep[u];
}
int main(){
// freopen("1.in","r",stdin);
// freopen(".out","w",stdout);
read(n);
int x,y;
for(int i=1;i<n;i++) read(x), read(y), e.ae(x,y), e.ae(y,x);
dep[1]=1; dfs0(1,-1);
dfs1(1,-1);
// for(int i=1;i<=n;i++) printf("%d\n",son[i]);
for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
return 0;
}
回到長鏈剖分,我們考慮在長鏈剖分序上DP,同一條鏈上的點一定連續。
考慮在長鏈剖分序上DP,我們記\(f(x,j)\)表示距離x點距離爲j的點個數。每個點在繼承重兒子信息時,指針移位即可(它們一定連續)。然後暴力合併輕兒子記錄的信息。因爲每條長鏈只會在鏈頭被暴力合併一次,總時間複雜度\(O(n)\)
#include<bits/stdc++.h>
using namespace std;
#define r(x) read(x)
#define ll long long
#define it set<string>::iterator
const int N1=1e6+7;
template <typename _T> void read(_T &ret)
{
ret=0; _T fh=1; char c=getchar();
while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); }
while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); }
ret=ret*fh;
}
struct EDGE{
int to[N1*2],nxt[N1*2],head[N1],cte;
void ae(int u,int v)
{ cte++; to[cte]=v, nxt[cte]=head[u]; head[u]=cte; }
}e;
int n;
int len[N1],son[N1],*f[N1],ans[N1],*tot;
void dfs0(int u,int dad)
{
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad) continue;
dfs0(v,u);
if(len[v]>len[son[u]]) son[u]=v;
}
len[u]=len[son[u]]+1;
}
int dfs1(int u,int dad)
{
int ma=0;
if(son[u]) f[son[u]]=tot++, ma=dfs1(son[u],u)+1;
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad||v==son[u]) continue;
f[v]=tot++;
v=dfs1(v,u);
}
for(int j=e.head[u];j;j=e.nxt[j]){
int v=e.to[j]; if(v==dad||v==son[u]) continue;
for(int i=0;i<len[v];i++){
f[u][i+1]+=f[v][i];
if(f[u][i+1]>f[u][ma]) ma=i+1;
else if(f[u][i+1]==f[u][ma]&&i+1<ma) ma=i+1;
}
}
f[u][0]=1;
if(f[u][0]>f[u][ma]) ma=0;
else if(f[u][0]==f[u][ma]&&0<ma) ma=0;
// ans[u]=ma;
ans[u]=ma;
return ma;
}
int main(){
// freopen("1.in","r",stdin);
// freopen(".out","w",stdout);
read(n);
int x,y;
for(int i=1;i<n;i++) read(x), read(y), e.ae(x,y), e.ae(y,x);
dfs0(1,-1);
tot=(int*)malloc(N1*sizeof(int));
f[1]=tot++;
dfs1(1,-1);
// for(int i=1;i<=n;i++) printf("%d\n",son[i]);
for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
return 0;
}