樹鏈剖分 學習筆記

寫在前面,算法筆記只是針對個人複習用,要想學習這個算法還是要看更詳細的博客或視頻講解。

樹鏈剖分

基本概念不講,講一下各個函數的作用,以及基本實現原理;

首先要知道樹鏈剖分的作用就是把樹上問題轉化爲線性問題;

線性問題總所周知就有區間問題,有區間問題就有線段樹這種東西;

所以樹鏈剖分就是先把樹拆成一條條鏈(線性),然後用線段樹維護區間就行;

怎麼拆是最難的,我覺得搞懂dfs序是前提,因爲就是根據dfs序進行拆分的;

這裏不解釋了;

  1. dfs1 這個函數就是相當於預處理,預處理出樹的每個結點的 deep(深度),fa(父親),size(子樹大小),son(重兒子編號);
void dfs1(int p,int fat){
	size[p]=1,fa[p]=fat;
	int mx=-1;
	for(int i=head[p];~i;i=edge[i].nex){
		int q=edge[i].to;
		if(q!=fat){
			deep[q]=deep[p]+1;
			dfs1(q,p);
			size[p]+=size[q];
			if(size[q]>mx) son[p]=q,mx=size[q]; 
		}
	}
}
  1. dfs2 這個函數就是dfs序進行拆分了,然後求出每個點的 id(新編號),w(新編號點的權值),top(每個結點所在鏈的頂端)
void dfs2(int p,int topp){
	id[p]=++tot,w[tot]=a[p],top[p]=topp;
	if(!son[p]) return;//沒有重兒子,也就是葉子結點 
	dfs2(son[p],topp);//先處理重兒子 
	for(int i=head[p];~i;i=edge[i].nex){
		int q=edge[i].to;
		if(q==fa[p]||q==son[p]) continue;
		dfs2(q,q);
	}
}
  1. 然後就是最重要的詢問和修改函數了,因爲已經轉化爲線性了,所以只要求出要維護的區間的就行(用線段樹維護區間),然後就可以了;怎麼求出要維護的區間,就跟求LCA的思路差不多;
LL query_chain(int x,int y){
	LL ans=0;
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]]) swap(x,y);//保證x結點的頂點更深 
		(ans+=query(id[top[x]],id[x],1))%=p;
		x=fa[top[x]]; 
	}
	if(deep[x]<deep[y]) swap(x,y);
	(ans+=query(id[y],id[x],1))%=p;
	return ans;
}
LL query_tree(int x){
	LL ans=0;
	(ans+=query(id[x],id[x]+size[x]-1,1))%=p;
	return ans;
}

全部代碼來自這個題:

洛谷·P3384 【模板】輕重鏈剖分

代碼:

#include<bits/stdc++.h>
#define LL long long
#define pa pair<int,int>
#define ls k<<1
#define rs k<<1|1
#define inf 0x3f3f3f3f
using namespace std;
const int N=100100;
const int M=50100;
const LL mod=10007;
LL a[N],p,w[N];
int head[N],cnt,n,m,r;
int size[N],fa[N],son[N],deep[N];
int top[N],id[N],tot;
struct Node{
	int to,nex;
}edge[N*2];
struct node{
	int l,r;
	LL w,lz;
}tr[N*4];
void add(int p,int q){edge[cnt].to=q,edge[cnt].nex=head[p],head[p]=cnt++;}
void dfs1(int p,int fat){
	size[p]=1,fa[p]=fat;
	int mx=-1;
	for(int i=head[p];~i;i=edge[i].nex){
		int q=edge[i].to;
		if(q!=fat){
			deep[q]=deep[p]+1;
			dfs1(q,p);
			size[p]+=size[q];
			if(size[q]>mx) son[p]=q,mx=size[q]; 
		}
	}
}
void dfs2(int p,int topp){
	id[p]=++tot,w[tot]=a[p],top[p]=topp;
	if(!son[p]) return;//沒有重兒子,也就是葉子結點 
	dfs2(son[p],topp);//先處理重兒子 
	for(int i=head[p];~i;i=edge[i].nex){
		int q=edge[i].to;
		if(q==fa[p]||q==son[p]) continue;
		dfs2(q,q);
	}
}
////
void pp(int k){
	tr[k].w=tr[ls].w+tr[rs].w;
	tr[k].w%=p;
}
void pd(int k){
	if(tr[k].lz){
		(tr[ls].lz+=tr[k].lz)%=p,(tr[rs].lz+=tr[k].lz)%=p;
		(tr[ls].w+=1ll*(tr[ls].r-tr[ls].l+1)*tr[k].lz)%=p;
		(tr[rs].w+=1ll*(tr[rs].r-tr[rs].l+1)*tr[k].lz)%=p;
	}
	tr[k].lz=0;
}
void build(int l,int r,int k){
	tr[k].lz=tr[k].w=0,tr[k].l=l,tr[k].r=r;
	if(l==r){
		tr[k].w=(w[l]%p);
		return;
	}
	int d=(l+r)>>1;
	build(l,d,ls);
	build(d+1,r,rs);
	pp(k);
}
void update(int l,int r,LL w,int k){
	if(tr[k].l>=l&&tr[k].r<=r){
		(tr[k].lz+=w)%=p;
		(tr[k].w+=1ll*(tr[k].r-tr[k].l+1)*w)%=p;
		return;
	}
	pd(k);
	int d=(tr[k].l+tr[k].r)>>1;
	if(l<=d) update(l,r,w,ls);
	if(r>d) update(l,r,w,rs);
	pp(k);
}
LL query(int l,int r,int k){
	LL ans=0;
	if(tr[k].l>=l&&tr[k].r<=r) return tr[k].w%p;
	pd(k);
	int d=(tr[k].l+tr[k].r)>>1;
	if(l<=d) (ans+=query(l,r,ls))%=p;
	if(r>d) (ans+=query(l,r,rs))%=p;
	return ans%p;
}
///
LL query_chain(int x,int y){
	LL ans=0;
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]]) swap(x,y);//保證x結點的頂點更深 
		(ans+=query(id[top[x]],id[x],1))%=p;
		x=fa[top[x]]; 
	}
	if(deep[x]<deep[y]) swap(x,y);
	(ans+=query(id[y],id[x],1))%=p;
	return ans;
}
LL query_tree(int x){
	LL ans=0;
	(ans+=query(id[x],id[x]+size[x]-1,1))%=p;
	return ans;
}
void update_chain(int x,int y,LL z){
	while(top[x]!=top[y]){
		if(deep[top[x]]<deep[top[y]]) swap(x,y);//保證x結點的頂點更深 
		update(id[top[x]],id[x],z%p,1);
		x=fa[top[x]]; 
	}
	if(deep[x]<deep[y]) swap(x,y);
	update(id[y],id[x],z%p,1);
}
void update_tree(int x,LL z){
	update(id[x],id[x]+size[x]-1,z%p,1);
} 
int main(){
	memset(head,-1,sizeof(head));
	scanf("%d%d%d%lld",&n,&m,&r,&p);
	for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
	for(int i=1;i<n;i++){
		int x,y;scanf("%d%d",&x,&y);
		add(x,y),add(y,x);
	}
	dfs1(r,0);
	dfs2(r,r);
	build(1,n,1);//建線段樹 
	for(int i=1;i<=m;i++){
		int op;scanf("%d",&op);
		if(op==1){
			int x,y;LL z;scanf("%d%d%lld",&x,&y,&z);
			update_chain(x,y,z);
		}
		else if(op==2){
			int x,y;scanf("%d%d",&x,&y);
			printf("%lld\n",query_chain(x,y));
		}
		else if(op==3){
			int x;LL z;scanf("%d%lld",&x,&z);
			update_tree(x,z); 
		}
		else{
			int x;scanf("%d",&x);
			printf("%lld\n",query_tree(x));
		}
	}
	return 0;
}	
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章