【PKUWC2018】Minimax【線段樹合併】

題意:給定一棵nn個點的二叉樹,葉子的權值輸入給定且互不相同,非葉子結點ii的權值有pip_i的概率爲兒子結點權值最大值,1pi1-p_i的概率爲最小值。求根結點取每種值的概率。模998244353998244353

n3×105n\leq 3\times 10^5

這都能線段樹合併……覺了
f(u,x)f(u,x)uu點值爲xx的概率,l,rl,r爲它的左右兒子

容易寫出

f(u,x)=px[f(l,x)i=1x1f(r,i)+f(r,x)i=1x1f(l,i)]+(1px)[f(l,x)i=x+1mf(r,i)+f(r,x)i=x+1mf(l,i)]f(u,x)=p_x[f(l,x)\sum_{i=1}^{x-1}f(r,i)+f(r,x)\sum_{i=1}^{x-1}f(l,i)]+(1-p_x)[f(l,x)\sum_{i=x+1}^mf(r,i)+f(r,x)\sum_{i=x+1}^mf(l,i)]

考慮線段樹合併

設當前合併的區間是[L,R][L,R],在遞歸的時候順便維護兩個線段樹結點[1,L1][1,L-1][R+1,m][R+1,m]的和,乘到f(l,x)f(l,x)f(r,x)f(r,x)上面,維護一個乘法標記。

文字不太好講清楚,建議直接看代碼。

複雜度O(nlogn)O(n\log n)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
#define MAXN 300005
using namespace std;
inline int read()
{
	int ans=0;
	char c=getchar();
	while (!isdigit(c)) c=getchar();
	while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
	return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int qpow(int a,int p)
{
	int ans=1;
	while (p)
	{
		if (p&1) ans=(ll)ans*a%MOD;
		a=(ll)a*a%MOD;p>>=1;
	}
	return ans;
}
namespace SGT
{
	int ch[MAXN<<5][2],sum[MAXN<<5],mul[MAXN<<5],cnt;
	inline void update(int x){sum[x]=(sum[ch[x][0]]+sum[ch[x][1]])%MOD;}
	inline void pushmul(int x,int v){sum[x]=(ll)sum[x]*v%MOD,mul[x]=(ll)mul[x]*v%MOD;}
	inline void pushdown(int x)
	{
		if (mul[x]!=1)
		{
			pushmul(ch[x][0],mul[x]),pushmul(ch[x][1],mul[x]);
			mul[x]=1;
		}
	}
	inline int newnode(){return ++cnt,sum[cnt]=mul[cnt]=1,cnt;}
	void insert(int& x,int l,int r,int k)
	{
		x=newnode();
		if (l==r) return;
		int mid=(l+r)>>1;
		if (k<=mid) insert(ch[x][0],l,mid,k);
		else insert(ch[x][1],mid+1,r,k);
	}
	int merge(int x,int y,int l,int r,int xmul,int ymul,int v)
	{
		if (!x&&!y) return 0;
		if (!x) return pushmul(y,ymul),y;
		if (!y) return pushmul(x,xmul),x;
		int mid=(l+r)>>1;
		pushdown(x),pushdown(y);
		int xl=sum[ch[x][0]],xr=sum[ch[x][1]],yl=sum[ch[y][0]],yr=sum[ch[y][1]];
		ch[x][0]=merge(ch[x][0],ch[y][0],l,mid,(xmul+(MOD+1ll-v)*yr)%MOD,(ymul+(MOD+1ll-v)*xr)%MOD,v);
		ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r,(xmul+(ll)v*yl)%MOD,(ymul+(ll)v*xl)%MOD,v);
		return update(x),x;
	}
	void getans(int x,int l,int r,int* &ans)
	{
		if (l==r) return (void)(*(ans++)=sum[x]);
		pushdown(x);
		int mid=(l+r)>>1;
		getans(ch[x][0],l,mid,ans),getans(ch[x][1],mid+1,r,ans);
	}
}
using SGT::insert;
using SGT::merge;
using SGT::getans;
int rt[MAXN],ch[MAXN][2],p[MAXN],v[MAXN],m;
void dfs(int u)
{
	if (!ch[u][0]) return insert(rt[u],1,m,p[u]);
	dfs(ch[u][0]);
	if (!ch[u][1]) return (void)(rt[u]=rt[ch[u][0]]);
	dfs(ch[u][1]);
	rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,m,0,0,p[u]);
}
int ans[MAXN];
int main()
{
	int n=read();
	for (int i=1;i<=n;i++)
	{
		int f=read();
		if (!f) continue;
		if (!ch[f][0]) ch[f][0]=i;
		else ch[f][1]=i;
	}
	int t=qpow(10000,MOD-2);
	for (int i=1;i<=n;i++)
	{
		p[i]=read();
		if (ch[i][0]) p[i]=(ll)p[i]*t%MOD;
		else v[++m]=p[i];
	}
	sort(v+1,v+m+1);
	for (int i=1;i<=n;i++)
		if (!ch[i][0])
			p[i]=lower_bound(v+1,v+m+1,p[i])-v;
	dfs(1);
	int* p=ans+1;
	getans(rt[1],1,m,p);
	int res=0;
	for (int i=1;i<=m;i++) res=(res+(ll)i*v[i]%MOD*ans[i]%MOD*ans[i])%MOD;
	printf("%d\n",res);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章