【PKUWC2018】Minimax【线段树合并】

题意:给定一棵nn个点的二叉树,叶子的权值输入给定且互不相同,非叶子结点ii的权值有pip_i的概率为儿子结点权值最大值,1pi1-p_i的概率为最小值。求根结点取每种值的概率。模998244353998244353

n3×105n\leq 3\times 10^5

这都能线段树合并……觉了
f(u,x)f(u,x)uu点值为xx的概率,l,rl,r为它的左右儿子

容易写出

f(u,x)=px[f(l,x)i=1x1f(r,i)+f(r,x)i=1x1f(l,i)]+(1px)[f(l,x)i=x+1mf(r,i)+f(r,x)i=x+1mf(l,i)]f(u,x)=p_x[f(l,x)\sum_{i=1}^{x-1}f(r,i)+f(r,x)\sum_{i=1}^{x-1}f(l,i)]+(1-p_x)[f(l,x)\sum_{i=x+1}^mf(r,i)+f(r,x)\sum_{i=x+1}^mf(l,i)]

考虑线段树合并

设当前合并的区间是[L,R][L,R],在递归的时候顺便维护两个线段树结点[1,L1][1,L-1][R+1,m][R+1,m]的和,乘到f(l,x)f(l,x)f(r,x)f(r,x)上面,维护一个乘法标记。

文字不太好讲清楚,建议直接看代码。

复杂度O(nlogn)O(n\log n)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
#define MAXN 300005
using namespace std;
inline int read()
{
	int ans=0;
	char c=getchar();
	while (!isdigit(c)) c=getchar();
	while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
	return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int qpow(int a,int p)
{
	int ans=1;
	while (p)
	{
		if (p&1) ans=(ll)ans*a%MOD;
		a=(ll)a*a%MOD;p>>=1;
	}
	return ans;
}
namespace SGT
{
	int ch[MAXN<<5][2],sum[MAXN<<5],mul[MAXN<<5],cnt;
	inline void update(int x){sum[x]=(sum[ch[x][0]]+sum[ch[x][1]])%MOD;}
	inline void pushmul(int x,int v){sum[x]=(ll)sum[x]*v%MOD,mul[x]=(ll)mul[x]*v%MOD;}
	inline void pushdown(int x)
	{
		if (mul[x]!=1)
		{
			pushmul(ch[x][0],mul[x]),pushmul(ch[x][1],mul[x]);
			mul[x]=1;
		}
	}
	inline int newnode(){return ++cnt,sum[cnt]=mul[cnt]=1,cnt;}
	void insert(int& x,int l,int r,int k)
	{
		x=newnode();
		if (l==r) return;
		int mid=(l+r)>>1;
		if (k<=mid) insert(ch[x][0],l,mid,k);
		else insert(ch[x][1],mid+1,r,k);
	}
	int merge(int x,int y,int l,int r,int xmul,int ymul,int v)
	{
		if (!x&&!y) return 0;
		if (!x) return pushmul(y,ymul),y;
		if (!y) return pushmul(x,xmul),x;
		int mid=(l+r)>>1;
		pushdown(x),pushdown(y);
		int xl=sum[ch[x][0]],xr=sum[ch[x][1]],yl=sum[ch[y][0]],yr=sum[ch[y][1]];
		ch[x][0]=merge(ch[x][0],ch[y][0],l,mid,(xmul+(MOD+1ll-v)*yr)%MOD,(ymul+(MOD+1ll-v)*xr)%MOD,v);
		ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r,(xmul+(ll)v*yl)%MOD,(ymul+(ll)v*xl)%MOD,v);
		return update(x),x;
	}
	void getans(int x,int l,int r,int* &ans)
	{
		if (l==r) return (void)(*(ans++)=sum[x]);
		pushdown(x);
		int mid=(l+r)>>1;
		getans(ch[x][0],l,mid,ans),getans(ch[x][1],mid+1,r,ans);
	}
}
using SGT::insert;
using SGT::merge;
using SGT::getans;
int rt[MAXN],ch[MAXN][2],p[MAXN],v[MAXN],m;
void dfs(int u)
{
	if (!ch[u][0]) return insert(rt[u],1,m,p[u]);
	dfs(ch[u][0]);
	if (!ch[u][1]) return (void)(rt[u]=rt[ch[u][0]]);
	dfs(ch[u][1]);
	rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,m,0,0,p[u]);
}
int ans[MAXN];
int main()
{
	int n=read();
	for (int i=1;i<=n;i++)
	{
		int f=read();
		if (!f) continue;
		if (!ch[f][0]) ch[f][0]=i;
		else ch[f][1]=i;
	}
	int t=qpow(10000,MOD-2);
	for (int i=1;i<=n;i++)
	{
		p[i]=read();
		if (ch[i][0]) p[i]=(ll)p[i]*t%MOD;
		else v[++m]=p[i];
	}
	sort(v+1,v+m+1);
	for (int i=1;i<=n;i++)
		if (!ch[i][0])
			p[i]=lower_bound(v+1,v+m+1,p[i])-v;
	dfs(1);
	int* p=ans+1;
	getans(rt[1],1,m,p);
	int res=0;
	for (int i=1;i<=m;i++) res=(res+(ll)i*v[i]%MOD*ans[i]%MOD*ans[i])%MOD;
	printf("%d\n",res);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章