【THUPC2019】不用找的樹 / tree(樹分塊)(樹形DP)

傳送門


是我沒見過的樹分塊姿勢,出題人寫的題解跟****一樣,只有會這道題的人才看得懂,反正我沒看懂,我看了std那8K難以形容的代碼才知道題解真就NM亂寫。

嚴格來說,只有分塊的具體方法那裏是亂寫,但是那裏亂寫直接導致了後續的無法理解。

出題人:反正預計也不會有人想去寫,題解亂寫就行了。

那亂寫說是垃圾也就不爲過了。

所以這裏來一個我的視角的題解。

更離譜的是樹上距離的定義居然是路徑上點數而不是邊數,答案相差的值需要統計距離一個點不超過 dd 的點的數量才能算出來,又要寫個數據結構,所以好點的辦法就是直接算出來答案,於是一大堆細節和一般寫的樹上處理都不一樣,真的就TM離譜。

題解:

首先考慮一個暴力,只有一個詢問,怎麼 O(n)O(n) 處理出答案。

容易注意到我們只需要求出樹上每個點作爲了多少個路徑的LCA即可,一次dfs算出子樹中有多少個 AA 集合的點,然後一次dfs在每個 BB 集合的點處統計答案。

這個暴力複雜度只和樹上的點數相關。

考慮樹分塊,這裏題解也是在TM瞎說,出題人****。

我們希望分塊後各個塊的相鄰關係也是樹形結構,考慮如下的分塊策略:

塊大小往 O(n)O(\sqrt n) 靠,這是常識。

一個點可能存在於多個塊中,每個塊最多允許兩個節點存在於其他塊中(題解裏面根本沒有描述這個細節,直接相鄰兩個字就帶過了,NM重合真就當相鄰了?)。每個塊一定有一個節點深度最高,稱爲 top,另一個節點我們要求其在塊中沒有兒子,且它可能存在於另外的塊中(都是作爲top),這個節點稱爲bot,注意一個塊如果是葉子塊,我們允許其沒有bot。

這樣塊與塊之間的相鄰關係仍然是樹形。

定義鄰域域表示在某個塊中距離某個點不超過 dd 的節點集合。

那麼一個領域可以拆到各個快裏面,且除了中心所在的塊,其他塊中的鄰域中心都是top或bot。

考慮將詢問表示成這幾種詢問的結果之和。

  1. 不同塊中的鄰域的距離之和。
  2. 同一塊內,且兩個中心處於兩個端點(可能處於同一個端點)的領域的距離之和。
  3. 同一塊內,一個或兩個中心不處於端點的鄰域的距離之和。

首先考慮第三種詢問,每個大詢問只會導致 O(1)O(1) 個這種小詢問,在塊內 O(n)O(\sqrt n) DP一下即可。

考慮第二種詢問,由於一個塊只有兩個端點,那麼一個塊需要處理的本質不同的詢問只有 Θ(mxdep2)O(n2)=O(n)\Theta(mxdep^2)\leq O(\sqrt n^2)=O(n),對於 O(n)O(\sqrt n) 個塊,可以每個暴力處理出來答案,我寫的是最後離線每個塊搞一搞。

然後是第一種詢問,由於塊之間的相鄰關係是樹形,樹形DP一下即可。

不建議寫,但如果已經下定決心要寫也不要放棄,大力寫總是有奇蹟的(雖然這個奇蹟花了三天)。


代碼:

#include<bits/stdc++.h>
#define ll long long
#define re register
#define cs const
#define int ll

namespace IO{

inline char gc(){
	static cs int Rlen=1<<22|1;static char buf[Rlen],*p1,*p2;
	return (p1==p2)&&(p2=(p1=buf)+fread(buf,1,Rlen,stdin),p1==p2)?EOF:*p1++;
}template<typename T>T get_integer(){
	char c;bool f=false;while(!isdigit(c=gc()))f=c=='-';T x=c^48;
	while(isdigit(c=gc()))x=((x+(x<<2))<<1)+(c^48);return f?-x:x;
}inline int gi(){return get_integer<int>();}

char obuf[(int)(3e7+7)],*oh=obuf,ch[23];
template<typename T>void print(T a,char c){
	if(a<0)*oh++='-',a=-a;int tl=0;
	do ch[++tl]=a%10;while(a/=10);
	while(tl)*oh++=ch[tl--]^48;*oh++=c;
}struct obuf_flusher{~obuf_flusher(){fwrite(obuf,1,oh-obuf,stdout);}}Flusher;

}using namespace IO;

using std::cerr;

cs int N=1e5+7,BN=3e3+7;

int n,m,B;ll ans[N];
std::vector<int> G[N];
int fa[N],bt[N],sz[N];
int tr[N],bel[N],ict;
int q[N],ql,qr,bct;

struct Qry{int d0,d1,d,id;};
struct Block{
	int idl,sz,top,bot,*fa,*ch;
	int *d[2],mxd[2],*dct[2],*dsm[2][2];
	std::vector<Qry> qry;
	inline int id(int x){return tr[x]-idl;}
	void build(int tp,int bt,int bid){
		top=tp,bot=bt;q[0]=tp;
		tr[tp]=idl=ict;
		while(ql!=qr){
			int u=q[++ql];
			bel[tr[u]=++ict]=bid;
			if(::sz[u]>1)
				for(int v:G[u])
					q[++qr]=v;
		}sz=qr+1;
		fa=new int[sz];ch=new int[sz+1];
		memset(fa,0,sizeof(int)*sz);
		memset(ch,0,sizeof(int)*(sz+1));
		int p=0;for(int re i=1;i<sz;++i){
			fa[i]=tr[::fa[q[i]]]-idl;
			while(p<fa[i])ch[++p]=i;
		}while(p<sz)ch[++p]=sz;
		d[0]=new int[sz];d[0][0]=0;
		for(int re i=1;i<sz;++i)
			d[0][i]=d[0][fa[i]]+1;
		mxd[0]=d[0][sz-1]+1;
		cal_dc(0);cal_dsm(0,0);
		if(bt){
			d[1]=new int[sz];
			memset(d[1],0,sizeof(int)*sz);
			ql=0;d[1][q[qr=1]=id(bt)]=1;
			while(ql<qr){
				int u=q[++ql],D=(u==id(bt)?0:d[1][u])+1;
				if(!d[1][fa[u]])d[1][q[++qr]=fa[u]]=D;
				for(int re i=ch[u];i<ch[u+1];++i)
					if(!d[1][i])d[1][q[++qr]=i]=D;
			}d[1][id(bt)]=0;mxd[1]=d[1][q[qr]]+1;
			cal_dc(1);cal_dsm(0,1),cal_dsm(1,0),cal_dsm(1,1);
		} 
	}void cal_dc(int t){
		int md=mxd[t],*dc=dct[t]=new int[md];
		memset(dc,0,sizeof(int)*md);
		for(int re i=1;i<sz;++i)++dc[d[t][i]];
		for(int re i=1;i<md;++i)dc[i]+=dc[i-1];
	}void cal_dsm(int k,int l){
		int md=mxd[k],*d0=d[k],*d1=d[l];
		int *ds=dsm[k][l]=new int[md];
		memset(ds,0,sizeof(int)*md);
		for(int re i=1;i<sz;++i)ds[d0[i]]+=d1[i];
		for(int re i=1;i<md;++i)ds[i]+=ds[i-1];
	}void get(int u,int dl,int *ed,int &ds0,int &ds1,int &dc)cs{
		memset(ed,0,sizeof(int)*sz);
		if(dl>mxd[0]+d[0][u]){
			ds0=dsm[0][0][mxd[0]-1];
			if(bot)ds1=dsm[0][1][mxd[0]-1];
			dc=sz-1;
			for(int re i=sz-1;i;--i)
				ed[fa[i]]+=++ed[i];
			return ;
		}ql=0;ed[q[qr=1]=u]=1;
		for(;ql<qr&&dl>0;--dl){
			for(int re r=qr;ql<r;){
				int u=q[++ql];
				for(int re i=ch[u];i<ch[u+1];++i)
					if(!ed[i])ed[q[++qr]=i]=1;
				if(!ed[fa[u]])ed[q[++qr]=fa[u]]=1; 
			}
		}
		dc=qr-ed[0];
		for(int re i=1;i<sz;++i)
			if(ed[i]){
				ds0+=d[0][i];
				if(bot)ds1+=d[1][i];
			}
		for(int re i=sz-1;i;--i)
			ed[fa[i]]+=ed[i];
	}ll calc(int *a,int *b)cs{
		ll r=0;
		for(int re i=1;i<sz;++i)
			r+=(ll)a[i]*b[i];
		return r;
	}void solve(int,int)cs;
}bl[BN];

void build_block(int tp,int bt){
	if(ql==qr)return ;
	q[0]=tp;++bct;
	bl[bct].build(tp,bt,bct);
	ql=qr=0;
}

void build_dfs(int u){
	int s=0,d=0;
	for(int v:G[u]){
		build_dfs(v);s+=sz[v];
		if(bt[v])++d,bt[u]=bt[v];
	}
	if(s>=B||d>=2||u==0){
		bt[u]=u;s=0,d=0;
		std::sort(G[u].begin(),G[u].end(),
			[](int u,int v){return sz[u]<sz[v];});
		for(int v:G[u]){
			if(s>=B||(d&&bt[v])){
				build_block(u,d);
				s=0,d=0;
			}if(!d)d=bt[v];
			q[++qr]=v,s+=sz[v];
		}build_block(u,d),sz[u]=1;
	}else sz[u]=1+s;
}

namespace BlockTree{

int fa[BN],ch[BN],cr[BN],len[BN];

void build(){
	build_dfs(0);
	for(int re i=1;i<=bct;++i){
		auto &t=bl[i];
		if(t.bot)len[i]=t.d[0][t.id(t.bot)];
		for(int re j=i+1;j<=bct;++j)
			if(bl[j].bot==t.top)
				{++cr[fa[i]=j];break;}
	}
	for(int re i=1;i<=bct+1;++i)cr[i]+=cr[i-1];
	for(int re i=1;i<bct;++i)ch[--cr[fa[i]]]=i;
}
struct node{
	ll ds;int sz;node(){}node(ll a,int b):ds(a),sz(b){}
	void inc(cs node &a,int d){ds+=a.ds+(ll)a.sz*d,sz+=a.sz;}
	node operator-(cs node &rhs)cs{return node(ds-rhs.ds,sz-rhs.sz);}
	ll operator^(cs node &rhs)cs{return ds*rhs.sz+rhs.ds*sz;}
};
struct Sol{
	int ct[BN],ctd[BN];
	node ds0[BN],ds1[BN];
	int b,b1,Ds0,Ds1,dc;
	void init(int p,int d,int _b1,int qid){
		b1=_b1;b=bel[p];
		memset(ds0,0,sizeof(node)*(bct+2));
		memset(ds1,0,sizeof(node)*(bct+2));
		memset(ctd,0,sizeof(int)*bl[b1].sz);
		auto &T=bl[b];Ds0=Ds1=dc=0;p-=T.idl;
		T.get(p,d,ct,Ds0,Ds1,dc);
		ds0[b]={Ds0,dc};if(T.bot)ds1[b]={Ds1,dc};
		T.qry.push_back({T.d[0][p],(T.bot?T.d[1][p]:0),d,qid});
		if(fa[b]){
			int f=fa[b],d1=d-T.d[0][p];
			dfs(f,d1,1);
			for(int re i=cr[f];i<cr[f+1];++i)
				if(ch[i]!=b)dfs(ch[i],d1,0);
		}if(T.bot){
			int d1=d-T.d[1][p];
			for(int re i=cr[b];i<cr[b+1];++i)
				dfs(ch[i],d1,0);
		}
	}
	void dfs(int u,int d,int tp){
		if(d<0)return ;auto &T=bl[u];
		int d1=std::min(T.mxd[tp]-1,d),dc=T.dct[tp][d1];
		ds0[u]={T.dsm[tp][0][d1],dc};
		if(T.bot)ds1[u]={T.dsm[tp][1][d1],dc};
		if(u==b1){
			for(int re i=1;i<T.sz;++i)
				ctd[i]=(T.d[tp][i]<=d);
			for(int re i=T.sz-1;i;--i)
				ctd[T.fa[i]]+=ctd[i];
		}d-=len[u];
		if(tp){
			int f=fa[u];if(!f)return;dfs(f,d,1);
			for(int re i=cr[f];i<cr[f+1];++i)
				if(ch[i]!=u)dfs(ch[i],d,0);
		}else for(int re i=cr[u];i<cr[u+1];++i)
			dfs(ch[i],d,0);
	}
	void calc_all(){
		for(int re u=1;u<bct;++u)
			ds0[fa[u]].inc(ds0[u],len[fa[u]]);
		for(int re u=bct;u;--u){
			for(int re i=cr[u];i<cr[u+1];++i)
				ds1[u].inc(ds0[ch[i]],0);
			for(int re i=cr[u];i<cr[u+1];++i)
				ds1[ch[i]].inc(ds1[u]-ds0[ch[i]],len[ch[i]]);
		}
	}
}t0,t1;
void Query(int p0,int d0,int p1,int d1,int qid){
	t0.init(p0,d0,bel[p1],qid);
	t1.init(p1,d1,bel[p0],qid);
	ll s=0;int b0=bel[p0],b1=bel[p1];
	auto &B0=bl[b0],&B1=bl[b1];
	if(b0==b1){
		s+=(t0.ds0[b0]^t1.ds0[b1])-2*B0.calc(t0.ct,t1.ct);
	}else {
		s+=(t0.ds0[b0]^t1.ds0[b0])-2*B0.calc(t0.ct,t1.ctd);
		s+=(t0.ds0[b1]^t1.ds0[b1])-2*B1.calc(t1.ct,t0.ctd);
	}
	t0.calc_all();
	for(int re u=1;u<=bct;++u){
		for(int re i=cr[u];i<cr[u+1];++i){
			int v=ch[i];
			ans[qid]+=(t0.ds1[u]-t0.ds0[v])^t1.ds0[v];
			ans[qid]+=t0.ds0[v]^t1.ds1[u];
		}
	}ans[qid]+=s;
}
int cur_d,qs[N][2],qp[N];
void dfs(int u,int d,int tp){
	for(auto &q:bl[u].qry){
		int d1=q.d-(tp?q.d1:q.d0)-d;
		if(d1>=0)
			qs[q.id][qp[q.id]++]=d1*2+cur_d;
	}d+=len[u];
	if(tp){
		int f=fa[u];if(f)dfs(f,d,1);
		for(int re i=cr[f];i<cr[f+1];++i)
			if(ch[i]!=u)dfs(ch[i],d,0);
	}else for(int re i=cr[u];i<cr[u+1];++i)
		dfs(ch[i],d,0);
}
void solve(){
	for(int re u=1;u<=bct;++u){
		memset(qs,0,sizeof(int)*(m+2)*2);
		memset(qp,0,sizeof(int)*(m+2));
		if(fa[u]){
			int p=fa[u];cur_d=0;dfs(p,0,1);
			for(int re i=cr[p];i<cr[p+1];++i)
				if(ch[i]!=u)dfs(ch[i],0,0);
		}cur_d=1;
		for(int re i=cr[u];i<cr[u+1];++i)
			dfs(ch[i],0,0);
		auto &T=bl[u];
		T.solve(0,0);
		if(T.bot){
			T.solve(0,1);
			T.solve(1,1);
		}
	}
}

}

void Block::solve(int t0,int t1)cs{
	using BlockTree::qs;
	using BlockTree::qp;
	int md0=mxd[t0],md1=mxd[t1];
	int *d0=d[t0],*d1=d[t1];
	static ll tmp[BN][BN],vl[BN];
	for(int re a=0;a<md0;++a){
		memset(vl,0,sizeof(ll)*sz);
		for(int re i=1;i<sz;++i)
			vl[i]=(d0[i]<=a);
		for(int re i=sz-1;i;--i)
			vl[fa[i]]+=vl[i];
		vl[0]=0;
		for(int re i=1;i<sz;++i)
			vl[i]+=vl[fa[i]];
		memset(tmp[a],0,sizeof(ll)*md1);
		for(int re i=1;i<sz;++i)
			tmp[a][d1[i]]+=vl[i];
		for(int re i=1;i<sz;++i)
			tmp[a][i]+=tmp[a][i-1];
	}
	for(int re i=1;i<=m;++i)if(qp[i]==2){
		int a=qs[i][0],b=qs[i][1];
		if(!(((a^t0)|(b^t1))&1)){
			a=std::min(md0-1,a>>1);
			b=std::min(md1-1,b>>1);
			ans[i]+=
				dct[t0][a]*dsm[t1][0][b]+
				dct[t1][b]*dsm[t0][0][a]-
				2*tmp[a][b];
		}
	}
}

void Main(){
	n=gi();B=sqrt(n)*2;
	G[0].push_back(1);
	for(int i=2;i<=n;++i)
		G[fa[i]=gi()].push_back(i);
	BlockTree::build();m=gi();
	for(int re i=1;i<=m;++i){
		int p0=gi(),d0=gi(),p1=gi(),d1=gi();
		BlockTree::Query(tr[p0],d0,tr[p1],d1,i);
	}BlockTree::solve();
	for(int re i=1;i<=m;++i)
		print(ans[i],'\n');
}

inline void file(){
#ifdef zxyoi
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
#endif
}signed main(){file();Main();return 0;} 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章