題意:給定一棵個點的二叉樹,葉子的權值輸入給定且互不相同,非葉子結點的權值有的概率爲兒子結點權值最大值,的概率爲最小值。求根結點取每種值的概率。模。
這都能線段樹合併……覺了
設爲點值爲的概率,爲它的左右兒子
容易寫出
考慮線段樹合併
設當前合併的區間是,在遞歸的時候順便維護兩個線段樹結點和的和,乘到和上面,維護一個乘法標記。
文字不太好講清楚,建議直接看代碼。
複雜度
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
#define MAXN 300005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int qpow(int a,int p)
{
int ans=1;
while (p)
{
if (p&1) ans=(ll)ans*a%MOD;
a=(ll)a*a%MOD;p>>=1;
}
return ans;
}
namespace SGT
{
int ch[MAXN<<5][2],sum[MAXN<<5],mul[MAXN<<5],cnt;
inline void update(int x){sum[x]=(sum[ch[x][0]]+sum[ch[x][1]])%MOD;}
inline void pushmul(int x,int v){sum[x]=(ll)sum[x]*v%MOD,mul[x]=(ll)mul[x]*v%MOD;}
inline void pushdown(int x)
{
if (mul[x]!=1)
{
pushmul(ch[x][0],mul[x]),pushmul(ch[x][1],mul[x]);
mul[x]=1;
}
}
inline int newnode(){return ++cnt,sum[cnt]=mul[cnt]=1,cnt;}
void insert(int& x,int l,int r,int k)
{
x=newnode();
if (l==r) return;
int mid=(l+r)>>1;
if (k<=mid) insert(ch[x][0],l,mid,k);
else insert(ch[x][1],mid+1,r,k);
}
int merge(int x,int y,int l,int r,int xmul,int ymul,int v)
{
if (!x&&!y) return 0;
if (!x) return pushmul(y,ymul),y;
if (!y) return pushmul(x,xmul),x;
int mid=(l+r)>>1;
pushdown(x),pushdown(y);
int xl=sum[ch[x][0]],xr=sum[ch[x][1]],yl=sum[ch[y][0]],yr=sum[ch[y][1]];
ch[x][0]=merge(ch[x][0],ch[y][0],l,mid,(xmul+(MOD+1ll-v)*yr)%MOD,(ymul+(MOD+1ll-v)*xr)%MOD,v);
ch[x][1]=merge(ch[x][1],ch[y][1],mid+1,r,(xmul+(ll)v*yl)%MOD,(ymul+(ll)v*xl)%MOD,v);
return update(x),x;
}
void getans(int x,int l,int r,int* &ans)
{
if (l==r) return (void)(*(ans++)=sum[x]);
pushdown(x);
int mid=(l+r)>>1;
getans(ch[x][0],l,mid,ans),getans(ch[x][1],mid+1,r,ans);
}
}
using SGT::insert;
using SGT::merge;
using SGT::getans;
int rt[MAXN],ch[MAXN][2],p[MAXN],v[MAXN],m;
void dfs(int u)
{
if (!ch[u][0]) return insert(rt[u],1,m,p[u]);
dfs(ch[u][0]);
if (!ch[u][1]) return (void)(rt[u]=rt[ch[u][0]]);
dfs(ch[u][1]);
rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,m,0,0,p[u]);
}
int ans[MAXN];
int main()
{
int n=read();
for (int i=1;i<=n;i++)
{
int f=read();
if (!f) continue;
if (!ch[f][0]) ch[f][0]=i;
else ch[f][1]=i;
}
int t=qpow(10000,MOD-2);
for (int i=1;i<=n;i++)
{
p[i]=read();
if (ch[i][0]) p[i]=(ll)p[i]*t%MOD;
else v[++m]=p[i];
}
sort(v+1,v+m+1);
for (int i=1;i<=n;i++)
if (!ch[i][0])
p[i]=lower_bound(v+1,v+m+1,p[i])-v;
dfs(1);
int* p=ans+1;
getans(rt[1],1,m,p);
int res=0;
for (int i=1;i<=m;i++) res=(res+(ll)i*v[i]%MOD*ans[i]%MOD*ans[i])%MOD;
printf("%d\n",res);
return 0;
}