[WC2019]數樹

Description

有兩棵n個點的樹,T1和T2,每個點可以填一個[1,y]的顏色
若兩棵樹有一條邊重合這條邊兩個端點的顏色必須相同
有三種問題:
op=0,給出T1和T2問答案
op=1,給出T1問所有T2的答案和
op=2,給出n問所有的T1和T2的答案和
n<=100000

Solution

op=0大家都會
op=1的話,考慮算恰好有i條重邊的yiy^{-i}的和
考慮到zm=i=0m(z1)i(im)z^m=\sum_{i=0}^{m}(z-1)^i(^m_i)
於是我們可以變成,枚舉i條重邊,其餘邊隨意的(z-1)^i的和
i條邊會把原樹分成n-i個連通塊,設ai爲連通塊大小,根據purfer序我們可以推出生成樹個數爲nni2ain^{n-i-2}\prod ai
考慮到ai\prod ai的組合意義爲從每個連通塊裏選一個點的方案數
設F[x][0…1]表示當前x爲根的連通塊有沒有選這個點
直接Dp是O(n)的
op=2根據op=1斷掉每條重邊求出一個連通塊的EGF然後exp一下就好了

Code

#include <map>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep(i,a) for(int i=lst[a];i;i=nxt[i])
using namespace std;

typedef long long ll;

int read() {
	char ch;
	for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
	int x=ch-'0';
	for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
	return x;
}

const int N=1e6+5,Mo=998244353;

int pwr(int x,int y) {
	int z=1;
	for(;y;y>>=1,x=(ll)x*x%Mo)
		if (y&1) z=(ll)z*x%Mo;
	return z;
}

map<int,int> e[N];

int n,y,z;

void solve_0() {
	fo(i,1,n-1) {
		int x=read(),y=read();
		e[x][y]=e[y][x]=1;
	}
	int cnt=0;
	fo(i,1,n-1) {
		int x=read(),y=read();
		cnt+=e[x][y];
	}
	printf("%d\n",pwr(y,n-cnt));
}

int t[N],nxt[N],lst[N],l;
void add(int x,int y) {t[++l]=y;nxt[l]=lst[x];lst[x]=l;}

int dp[N][2];

void Dp(int x,int y) {
	rep(i,x) if (t[i]!=y) Dp(t[i],x);
	dp[x][0]=dp[x][1]=1;int s0,s1;
	rep(i,x)
		if (t[i]!=y) {
			// 選
			s0=(ll)dp[x][0]*dp[t[i]][1]%Mo*n%Mo;
			s1=(ll)dp[x][1]*dp[t[i]][1]%Mo*n%Mo;
			// 不選
			(s0+=(ll)dp[x][0]*dp[t[i]][0]%Mo*z%Mo)%=Mo;
			(s1+=(ll)dp[x][1]*dp[t[i]][0]%Mo*z%Mo)%=Mo;
			(s1+=(ll)dp[x][0]*dp[t[i]][1]%Mo*z%Mo)%=Mo;
			dp[x][0]=s0;dp[x][1]=s1;
		}
}

void solve_1() {
	fo(i,1,n-1) {
		int x=read(),y=read();
		add(x,y);add(y,x);
	}
	if (y==1) {
		printf("%d\n",pwr(n,n-2));
		return;
	}
	z=pwr(y,Mo-2);z--;
	Dp(1,0);
	int ans=(ll)dp[1][1]*pwr(n,Mo-2)%Mo;
	ans=(ll)ans*pwr(y,n)%Mo;
	printf("%d\n",ans);
}

ll W[2][N];

void init(int b) {
	for(int i=1;i<(1<<b);i<<=1){
		ll wn=pwr(3,(Mo-1)/(i<<1));
		for(int j=0;j<i;++j) W[1][i+j]=(j?wn*W[1][i+j-1]%Mo:1);
		wn=pwr(3,Mo-1-(Mo-1)/(i<<1));
		for(int j=0;j<i;++j) W[0][i+j]=(j?wn*W[0][i+j-1]%Mo:1);
	}
}

void DFT(ll *a,int len,int flag) {
	if (flag==-1) flag=0;
	for(int i=0,j=0;i<len;++i){
		if (i<j) swap(a[i],a[j]);
		for(int k=len>>1;(j^=k)<k;k>>=1);
	}
	for(int i=1;i<len;i<<=1)
		for(int j=0;j<len;j+=(i<<1))
			for(int k=0;k<i;++k) {
				ll x=a[j+k],y=a[j+k+i]*W[flag][i+k]%Mo;
				a[j+k]=(x+y)%Mo;
				a[j+k+i]=(x-y)%Mo;
			}
	ll inv=pwr(len,Mo-2);
	if (!flag) for(int i=0;i<len;i++) a[i]=a[i]*inv%Mo;
}

ll c[N];

void get_Inv(ll *a,ll *b,int n) {
	if (n==1) {b[0]=pwr(a[0],Mo-2);return;}
	get_Inv(a,b,n>>1);
	int len=n<<1;
	fo(i,0,n-1) c[i]=a[i];fo(i,n,len-1) c[i]=0;
	fo(i,(n>>1),len-1) b[i]=0;
	DFT(c,len,1);DFT(b,len,1);
	fo(i,0,len-1) b[i]=(2*b[i]-b[i]*b[i]%Mo*c[i])%Mo;
	DFT(b,len,-1);
	fo(i,n,len-1) b[i]=0;
}

ll f[N],g[N];

void get_ln(ll *a,ll *b,int n) {
	fo(i,0,n-2) f[i]=a[i+1]*(i+1)%Mo;f[n-1]=0;
	get_Inv(a,g,n);
	int len=n<<1;
	fo(i,n,len-1) f[i]=g[i]=0;
	DFT(f,len,1);DFT(g,len,1);
	fo(i,0,len-1) f[i]=f[i]*g[i]%Mo;
	DFT(f,len,-1);
	fo(i,1,n-1) b[i]=f[i-1]*pwr(i,Mo-2)%Mo;
	b[0]=0;
}

ll h[N];

void get_exp(ll *a,ll *b,int n) {
	if (n==1) {b[0]=1;return;}
	get_exp(a,b,n>>1);
	get_ln(b,h,n);
	fo(i,0,n-1) h[i]=(a[i]-h[i]+Mo)%Mo;
	(h[0]=h[0]+1)%=Mo;
	int len=n<<1;
	fo(i,n,len-1) h[i]=b[i]=0;
	DFT(h,len,1);DFT(b,len,1);
	fo(i,0,len-1) b[i]=b[i]*h[i]%Mo;
	DFT(b,len,-1);
	fo(i,n,len-1) b[i]=0;
}

ll F[N],G[N],fac[N],inv[N];

void solve_2() {
	init(18);
	if (y==1) {
		printf("%d\n",pwr(pwr(n,n-2),2));
		return;
	}
	z=pwr(y,Mo-2);int iz=pwr(z-1,Mo-2);
	fac[0]=1;fo(i,1,n) fac[i]=fac[i-1]*i%Mo;
	inv[n]=pwr(fac[n],Mo-2);fd(i,n-1,0) inv[i]=(ll)inv[i+1]*(i+1)%Mo;
	fo(i,1,n) F[i]=(ll)n*n%Mo*iz%Mo*pwr(i,i)%Mo*inv[i]%Mo;
	int len=1;for(;len<=n;len<<=1);get_exp(F,G,len);
	int ans=(ll)G[n]*fac[n]%Mo;
	ans=(ll)ans*pwr(pwr(n,Mo-2),4)%Mo*pwr(z-1,n)%Mo;
	ans=(ll)ans*pwr(y,n)%Mo;
	printf("%d\n",(ans+Mo)%Mo);
}

int main() {
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	n=read();y=read();int op=read();
	if (op==0) solve_0();
	if (op==1) solve_1();
	if (op==2) solve_2();
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章