題解:[GXOI/GZOI2019]舊詞

這個題目其實早就做了,只是突然發現還沒發,那就湊一下GZOI

題意:給定$x,y$求

$$\sum_{i\leq x}dep(lca(i,y))^k$$

首先我們先來看這個題目的簡化版 

https://www.luogu.org/problem/P4211

求  $$\sum_{i\leq x}dep(lca(i,y))$$

我們來看$dep$的實際意義——從 i 點到根有多少個點(包括 i )。

我們從整體上考慮,發現對於一個詢問:所有的 $lca$ 都在 $y$ 到根的路徑上。從而有一些點,它們對很多的 $lca$ 的深度都有貢獻,而這個貢獻等於在這個點下面的 $lca$ 的個數,所以我們可以把每個 $lca$ 到根的路徑上的每個點的權值都加一。然後從 $y$ 向上走到根,沿路統計的權值就是答案了。

這裏,我們可以把所有的詢問離線下來,按照 $x$  排序,然後每個節點就向上跳把所有的上面的點染上顏色,然後查詢的時候只需要向上找,看有多少染上顏色的節點,並且計算貢獻,這裏我們只需要用樹鏈剖分維護一下就行了

然後是我非常醜的代碼

#include <bits/stdc++.h>
using namespace std;

#define re register
#define ll long long
#define gc getchar()
inline ll read()
{
 	re ll x(0),f(1);re char c(gc);
    while(c>'9'||c<'0')f=c=='-'?-1:1,c=gc;
    while(c>='0'&&c<='9')x=x*10+c-48,c=gc;
    return f*x;
}

const ll N=50500,mod=201314;
ll n,Q,k,h[N],cnt,qs;
struct edge{ll next,to;}e[N];

void add(ll u,ll v){e[++cnt]=(edge){h[u],v},h[u]=cnt;}
#define QXX(u) for(ll i=h[u],v;v=e[i].to,i;i=e[i].next)

ll dep[N],fa[N],son[N],siz[N],top[N],rev[N],seq[N],tot;

void dfs(ll u)
{
	siz[u]=1;
	QXX(u)
	{
		dfs(v);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v])
			son[u]=v;
	}
}
void Dfs(ll u)
{
	if(son[u])
	{
		ll v=son[u];
		seq[v]=++tot;
		rev[tot]=v;
		top[v]=top[u];
		Dfs(v);
	}
	QXX(u)
	{
		if(v==son[u]) continue;
		seq[v]=++tot;
		rev[tot]=v;
		top[v]=v;
		Dfs(v);
	}
}

struct node{ll id,x,z,ans;bool w;}q[N<<1];
bool operator < (node a,node b){return a.x<b.x;}

#define ls id<<1
#define rs id<<1|1
#define mid ((l+r)>>1)

ll sum[N<<2],tag[N<<2];
void pushup(ll id){sum[id]=(sum[ls]+sum[rs])%mod;}

void pushdown(ll id,ll l,ll r)
{
	if(tag[id])
	{
		tag[ls]+=tag[id];
		tag[rs]+=tag[id];
		sum[ls]+=tag[id]*(mid-l+1);
		sum[rs]+=tag[id]*(r-mid);
		sum[ls]%=mod;
		sum[rs]%=mod;
		tag[id]=0;
	}
}

void change(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R)
	{
		tag[id]++;
		sum[id]+=(r-l+1);
		sum[id]%=mod;
		return;
	}
	pushdown(id,l,r);
	if(mid>=L) change(ls,l,mid,L,R);
	if(mid<R) change(rs,mid+1,r,L,R);
	pushup(id);
}

void work(ll x)
{
	while(1)
	{
		if(top[x]!=x)
			change(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			change(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		if(x==0) return;
	}
}

ll query(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R) return sum[id]%mod;
	pushdown(id,l,r);
	ll ans=0;
	if(mid>=L) ans+=query(ls,l,mid,L,R);
	if(mid<R) ans+=query(rs,mid+1,r,L,R);
	return ans%mod;
}

ll ask(ll x)
{
	ll ans=0;
	while(1)
	{
		if(top[x]!=x)
			ans+=query(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			ans+=query(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		ans%=mod;
		if(x==0) return ans;
	}
}

ll ans[N];

int main()
{
	n=read(),Q=read();
	for(ll i=2;i<=n;++i)
	{
		ll x=read()+1;
		add(x,i);
		fa[i]=x;
	}
	top[1]=1;seq[1]=++tot;
	dfs(1);Dfs(1);
	for(ll i=1;i<=Q;++i)
	{
		ll l=read(),r=read()+1,z=read()+1;
		q[++qs]=(node){i,l,z,0,0};
		q[++qs]=(node){i,r,z,0,1};
	}
	sort(q+1,qs+1+q);
	ll t=0;
	while(q[t+1].x<1) ++t;
	for(ll i=1;i<=n;++i)
	{
		work(i);
		while(q[t+1].x<=i&&t<qs)
			++t,q[t].ans=ask(q[t].z);
		if(t==qs) break;
	}
	for(ll i=1;i<=qs;++i)
	{
		if(q[i].w==0) ans[q[i].id]-=q[i].ans;
		else ans[q[i].id]+=q[i].ans;
	}
	for(ll i=1;i<=Q;++i)
		cout<<(ans[i]+mod)%mod<<endl;
	return 0;
}  

然後我們回到本題,這裏是多了一個 $k$ 次方

首先我們來看前面的每次$+1$是哪裏來的 $dep[i]->dep[i+1]$所以這裏實際上就是在做差分,那麼我們把指數換成 $k$

$dep[i]^k->(dep[i]+1)^k$

那麼,我們就預處理出來每一個 $dep^k$ 然後對於每個節點就相當於每次會增加 $dep[x]^k-(dep[x]-1)^$ 的貢獻

然後我們就可以轉化爲,對於一個序列,每個點的值是 $a*b$ 其中 $b$ 是定值,但是每個節點不一樣,每次操作就是做區間修改給 $a$ 加上1和區間查詢

然後我們維護線段樹的時候再多維護一個 $sum_b$ 就可以了

 

#include <bits/stdc++.h>
using namespace std;

#define re register
#define ll long long
#define gc getchar()
inline ll read()
{
 	re ll x(0),f(1);re char c(gc);
    while(c>'9'||c<'0')f=c=='-'?-1:1,c=gc;
    while(c>='0'&&c<='9')x=x*10+c-48,c=gc;
    return f*x;
}

const ll N=50050,mod=998244353;

ll n,Q,k,h[N],cnt,qs;
struct edge{ll next,to;}e[N];

void add(ll u,ll v){e[++cnt]=(edge){h[u],v},h[u]=cnt;}
#define QXX(u) for(ll i=h[u],v;v=e[i].to,i;i=e[i].next)

ll dep[N],fa[N],son[N],siz[N],top[N],seq[N],rev[N],tot;

void dfs(ll u)
{
	siz[u]=1;
	QXX(u)
	{
		dep[v]=dep[u]+1;
		dfs(v);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v])
			son[u]=v;
	}
}
void Dfs(ll u,ll to)
{
	seq[u]=++tot,rev[tot]=u;
	top[u]=to;
	if(son[u])
		Dfs(son[u],top[u]);
	QXX(u)
	{
		if(v==son[u]) continue;
		Dfs(v,v);
	}
}

ll qpow(ll x,ll b)
{
	ll a=1;
	while(b)
	{
		if(b&1) a=(x*a)%mod;
		x=(x*x)%mod,b>>=1;
	}
	return a;
}

#define ls id<<1
#define rs id<<1|1
#define mid ((l+r)>>1)

ll su[N<<2],sum[N<<2],tag[N<<2],po[N];

void pushup(ll id)
{
	sum[id]=sum[ls]+sum[rs];
	su[id]=su[ls]+su[rs];
}
void pushdown(ll id,ll l,ll r)
{
	if(tag[id])
	{
		tag[ls]+=tag[id];
		tag[rs]+=tag[id];
		sum[ls]=(sum[ls]+tag[id]*su[ls])%mod;
		sum[rs]=(sum[rs]+tag[id]*su[rs])%mod;
		tag[id]=0;
	}
}
void built(ll id,ll l,ll r)
{
	if(l==r)
	{
		su[id]=(po[dep[rev[l]]]+mod-po[dep[rev[l]]-1])%mod;
		return;
	}
	built(ls,l,mid);
	built(rs,mid+1,r);
	pushup(id);
}
void change(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R)
	{
		tag[id]++;
		sum[id]+=su[id];
		sum[id]%=mod;
		return;
	}
	pushdown(id,l,r);
	if(mid>=L) change(ls,l,mid,L,R);
	if(mid<R) change(rs,mid+1,r,L,R);
	pushup(id);
}
ll query(ll id,ll l,ll r,ll L,ll R)
{
	if(l>=L&&r<=R) return sum[id]%mod;
	pushdown(id,l,r);
	ll ans=0;
	if(mid>=L) ans+=query(ls,l,mid,L,R);
	if(mid<R) ans+=query(rs,mid+1,r,L,R);
	return ans%mod;
}
void work(ll x)
{
	while(1)
	{
		if(top[x]!=x)
			change(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			change(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		if(x==0) return;
	}
}
ll ask(ll x)
{
	ll ans=0;
	while(1)
	{
		if(top[x]!=x)
			ans+=query(1,1,tot,seq[top[x]],seq[x]),x=fa[top[x]];
		else
		{
			ans+=query(1,1,tot,seq[x],seq[x]);
			x=fa[x];
		}
		ans%=mod;
		if(x==0) return ans;
	}
}
struct node{ll id,x,y,ans;}q[N];
bool cmpx(node a,node b){return a.x<b.x;}
bool cmpi(node a,node b){return a.id<b.id;}

int main()
{
	n=read(),Q=read(),k=read();
	for(ll i=1;i<=n;++i)
		po[i]=qpow(i,k);
	for(ll i=2;i<=n;++i)
	{
		fa[i]=read();
		add(fa[i],i);
	}
	dep[1]=1;
	dfs(1),Dfs(1,1);
	built(1,1,tot);
	for(ll i=1;i<=Q;++i)
	{
		ll x=read(),y=read();
		q[i]=(node){i,x,y,0};
	}
	sort(q+1,q+1+Q,cmpx);
	ll t=0;
	while(q[t+1].x<1) ++t;
	for(ll i=1;i<=n;++i)
	{
		work(i);
		while(q[t+1].x<=i&&t<Q)
			++t,q[t].ans=ask(q[t].y);
		if(t==Q) break;
	}
	sort(q+1,q+1+Q,cmpi);
	for(ll i=1;i<=Q;++i)
		cout<<q[i].ans<<endl;
	return 0;
}

  

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章