FZOJ192. 「2019冬令營提高組」喫(點分治,NTT,概率與期望)

題目大意:
$n$ 個點的樹(或基環樹)
每次在剩餘的點中等概率隨機選取一個點進行操作:刪去該點,並將答案加上這個點(刪去前)所屬的連通塊的大小,直到所有點都被刪去
問答案的期望
$n\le 10^5,\ n-1\le m\le n$


先考慮樹怎麼做:刪去點 $x$ 時 $y$ 與 $x$ 連通,當且僅當 $x$ 是 $x$ 到 $y$ 的路徑上第一個被刪去的點,所以概率爲 $\frac{1}{dis(x,y)}$,其中 $dis(x,y)$ 表示 $x$ 到 $y$ 的路徑上的點的個數

所以答案爲 $\sum_{x}\sum_{y}\frac{1}{dis(x,y)}$,只需統計每種距離的點對個數,顯然可以點分治 $+$ NTT

考慮基環樹,如果路徑不經過環顯然不影響。

考慮經過環的路徑,設 $a\rightarrow b$ 的路徑交環於 $c,d$,
$z=dis(a,c)+dis(b,d)$,$x,y$ 分別爲 $c,d$ 之間兩條路徑上(不含 $c,d$)的點數
那麼容斥一下就可以得到概率 $\frac{1}{x+z}+\frac{1}{y+z}-\frac{1}{x+y+z}$

可以發現 $x+z$ 與 $y+z$ 分別爲兩條不同路徑對應的 $dis(a,b)$
考慮刪去環上的一條邊再點分治 $+$ NTT,漏掉的就是經過環的另一條路徑對應的 $dis(a,b)$ 那一項和第三項

漏掉的那一項,我們可以把環上所有子樹按照到刪去的邊的距離進行偏移後直接 NTT;此時同一子樹內的點對會算重,我們再對子樹點分治 $+$ NTT 把它減掉

第三項只要把所有子樹的信息合起來做一次 NTT,再減去同一子樹內的貢獻即可

代碼:

#include<bits/stdc++.h>
using namespace std;
// rep / repp: inclusive ascending / descending loops over i = j .. k.
// Note: the bound k is re-evaluated on every iteration.
#define rep(i,j,k) for(int i = j;i <= k;++i)
#define repp(i,j,k) for(int i = j;i >= k;--i)
#define ll long long
// file(x): zero an entire array; x must be a real array (not a pointer) for sizeof to cover it.
#define file(x) memset(x,0,sizeof(x))
#define pb push_back
// SZ(x): container size as a signed int.
#define SZ(x) ((int)(x.size()))
namespace io {
	// Buffered stdin/stdout helpers built on fread/fwrite.
	const int SIZE = (1 << 21) + 1;
	char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55]; int f, qr;
	// getchar: refill ibuf from stdin when it is exhausted; yields EOF once fread returns nothing.
	#define gc() (iS == iT ? (iT = (iS = ibuf) + fread (ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)
	// Write the pending part of obuf to stdout and reset the buffer.
	inline void flush () {
		fwrite (obuf, 1, oS - obuf, stdout);
		oS = obuf;
	}
	// putchar: append one byte, flushing when the buffer is full.
	inline void putc (char x) {
		*oS ++ = x;
		if (oS == oT) flush ();
	}
	// Read a signed integer into x.
	// Stops skipping at end of input instead of spinning forever; x is then 0.
	template <class I>
	inline void gi (I &x) {
		for (f = 1, c = gc(); (c < '0' || c > '9') && c != static_cast<char>(EOF); c = gc()) if (c == '-') f = -1;
		for (x = 0; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15);
		x *= f;
	}
	// Print a signed (or unsigned) integer.
	// Takes x by value so the caller's variable is not clobbered, and emits digits from
	// the (possibly negative) remainders so the most negative value does not overflow.
	template <class I>
	inline void print (I x) {
		if (x == 0) { putc ('0'); return; }
		const bool negative = x < 0;
		if (negative) putc ('-');
		while (x != 0) {
			const int digit = static_cast<int>(x % 10);
			qu[++ qr] = static_cast<char>('0' + (negative ? -digit : digit));
			x /= 10;
		}
		while (qr) putc (qu[qr --]);
	}
	// Flushes obuf at program exit, so no manual flush() call is needed.
	struct Flusher_ {~Flusher_(){flush();}}io_flusher_;
}
using io :: gi;
using io :: putc;
using io :: print;
// NTT-friendly prime modulus: p - 1 = 119 * 2^23.
const int p = 998244353;

/// (a + b) mod p; expects operands already reduced into [0, p).
inline int calc(int a,int b){
	const int sum = a + b;
	return sum % p;
}

/// (a - b) mod p; expects operands already reduced into [0, p).
inline int del(int a,int b){
	int diff = a - b;
	diff += p;
	return diff % p;
}

/// (a * b) mod p, with the product formed in 64 bits.
inline int mul(int a,int b){
	return static_cast<int>(static_cast<long long>(a) * b % p);
}

/// Larger of two ints.
inline int max(int a,int b){
	if (a > b) return a;
	return b;
}

/// a^x mod p by binary exponentiation (x must be non-negative).
inline int ksm(int a,int x){
	int result = 1;
	while (x) {
		if (x & 1) result = static_cast<int>(1LL * result * a % p);
		x >>= 1;
		a = static_cast<int>(1LL * a * a % p);
	}
	return result;
}
// n: vertex count, m: edge count.
// u, v: set to tmp[1] / tmp[tot] by find_circle (not read elsewhere in this file).
// tp: index of the cycle vertex whose hanging tree get_sz is currently recording.
int n,m,u,v,tp;
// Adjacency list: linkk[x] = first arc id of x, e[id].n = next arc id, e[id].y = arc target.
// t: number of arcs stored; du: vertex degree; son: not used in this file.
int linkk[101000],t,du[101000],son[101000];
struct node{int n,y;}e[201000];
// inv[i]: modular inverse of i (filled in Solve); g: not used in this file.
int inv[1001000],g[101000];
// tmp[1..tot]: cycle vertices in walk order (filled by find_circle); dep: depths assigned by get_sz.
int tmp[101000],tot,dep[101000];
// A, B: convolution buffers for f1 / f2 / f3.
int A[801000],B[801000];
// mx_dep[i]: deepest level recorded so far for the tree hanging off tmp[i].
int mx_dep[101000];
// in[i][d]: number of vertices at depth d in the tree hanging off cycle vertex tmp[i].
vector<int>in[101000];
namespace NTT{
    const int G = 3;
	int r[801000],w[801000];
	void ntt(int *a,int f,int flen){
		w[0] = 1; rep(i,0,flen-1) r[i] = (r[i>>1]>>1) | ((i&1)?flen/2:0);
		rep(i,0,flen-1) if(i < r[i]) swap(a[i],a[r[i]]);
		for(int len = 2;len <= flen;len <<= 1){
			int wn = ksm(G,(p-1)/len); if(f == -1) wn = ksm(wn,p-2);
			rep(i,1,len-1) w[i] = mul(w[i-1],wn);
			for(int st = 0;st < flen;st += len)
			    rep(i,0,(len>>1)-1){
			    	int x = a[st+i],y = mul(a[st+(len>>1)+i],w[i]);
			    	a[st+i] = calc(x,y);    a[st+(len>>1)+i] = del(x,y);
				}
		}
		if(f == -1){int inv = ksm(flen,p-2);rep(i,0,flen-1) a[i] = mul(a[i],inv);}
	}
	void Mul(int *a,int n,int *b,int m){
		int flen = 1; while(flen < n+m-1) flen <<= 1;
		rep(i,n,flen-1) a[i] = 0; rep(i,m,flen-1) b[i] = 0;
		ntt(a,1,flen); ntt(b,1,flen); rep(i,0,flen-1) a[i] = mul(a[i],b[i]);
		ntt(a,-1,flen);
	}
	void Mul2(int *a,int n){
		int flen = 1; while(flen < 2*n-1) flen <<= 1;
		rep(i,n,flen-1) a[i] = 0; ntt(a,1,flen); rep(i,0,flen-1) a[i] = mul(a[i],a[i]); ntt(a,-1,flen);
	}
}using namespace NTT;
namespace cir{
	queue<int>q;          // work queue for leaf peeling
	bool vis[101000];     // true = vertex was peeled off, i.e. not on the cycle

	// Find the unique cycle: peel degree-1 vertices, then walk the remaining ones.
	// Fills tmp[1..tot] with the cycle in walk order and sets u = tmp[1], v = tmp[tot].
	// Note: decrements du[] of vertices while peeling.
	void find_circle(){
		for (int s = 1; s <= n; ++s) {
			if (du[s] != 1) continue;
			vis[s] = true;
			q.push(s);
		}
		while (!q.empty()) {
			const int cur = q.front();
			q.pop();
			for (int id = linkk[cur]; id != 0; id = e[id].n) {
				const int to = e[id].y;
				if (vis[to]) continue;
				if (--du[to] == 1) {
					q.push(to);
					vis[to] = true;
				}
			}
		}

		// Start from the first unpeeled vertex.
		for (int s = 1; s <= n; ++s) {
			if (!vis[s]) {
				tmp[++tot] = s;
				break;
			}
		}
		int walker = tmp[tot];
		for (int id = linkk[walker]; id != 0; id = e[id].n) {
			if (!vis[e[id].y]) {
				walker = e[id].y;
				break;
			}
		}
		// Keep stepping to the unpeeled neighbour we did not just come from until back at the start.
		while (walker != tmp[1]) {
			tmp[++tot] = walker;
			const int came_from = tmp[tot - 1];
			for (int id = linkk[walker]; id != 0; id = e[id].n) {
				const int to = e[id].y;
				if (!vis[to] && to != came_from) {
					walker = to;
					break;
				}
			}
		}
		u = tmp[1];
		v = tmp[tot];
	}
}using namespace cir;
namespace work{
	int f[101000],sz[101000],rt,now_size;
	int ans[101000],a[401000],b[401000],mx,tot_mx;
	bool inq[101000],ok;
	void get_rt(int x,int fa){
		f[x] = 0;sz[x] = 1;
		for(int i = linkk[x];i;i = e[i].n) if(!inq[e[i].y] && e[i].y != fa) get_rt(e[i].y,x),f[x] = max(f[x],sz[e[i].y]),sz[x] += sz[e[i].y];
		f[x] = max(f[x],now_size-sz[x]); if(f[x] < f[rt]) rt = x;
	}
	void get_dis(int x,int fa,int v){a[v]++; mx=max(mx,v); for(int i = linkk[x];i;i = e[i].n) if(!inq[e[i].y] && e[i].y != fa) get_dis(e[i].y,x,v+1);}
	void tr_calc(int x,int v,int f){
		mx = 0;get_dis(x,0,f); 
		NTT::Mul2(a,mx+1); 
		rep(i,0,2*mx) ans[i] = calc(ans[i],mul(v,a[i])); 
		rep(i,0,2*mx) a[i] = 0; tot_mx = max(tot_mx,2*mx);
	}
	void solve(int x){
		inq[x] = true; tr_calc(rt,1,0);
		for(int i = linkk[x];i;i = e[i].n) if(!inq[e[i].y]) tr_calc(e[i].y,-1,1) , now_size = sz[e[i].y] , rt = 0 , get_rt(e[i].y,x) , solve(rt);
	}
	void doit(int x,int sz){
		if(x == 1 && sz == 2) ok = true;
		rt = 0;now_size = sz;f[0] = 2*n;
		get_rt(x,0);solve(rt);
		if(x == 1 && sz == 2) ok = false;
	}
	void again(){
		rep(i,0,tot_mx) ans[i] = 0;
		tot_mx = 0;
	}
}using namespace work;
void insert(int x,int y){
	// Store the undirected edge x-y as two arcs and bump both endpoint degrees.
	auto add_arc = [](int from, int to) {
		++t;
		e[t].y = to;
		e[t].n = linkk[from];
		linkk[from] = t;
		++du[from];
	};
	add_arc(x, y);
	add_arc(y, x);
}
void dele(int x,int y){if(e[linkk[x]].y == y) {linkk[x] = e[linkk[x]].n;return;}for(int i = linkk[x];i;i = e[i].n) if(e[e[i].n].y == y) {e[i].n = e[e[i].n].n;return;}}
void get_sz(int x,int fa){
    dep[x] = dep[fa] + 1; if(dep[x] > mx_dep[tp]) mx_dep[tp]++,in[tp].pb(0); in[tp][dep[x]]++; sz[x] = 1;
	for(int i = linkk[x];i;i = e[i].n) if(!inq[e[i].y] && e[i].y != fa) get_sz(e[i].y,x),sz[x] += sz[e[i].y];
}
// Never assigned anywhere in this file, so it keeps its zero static initialisation;
// the subtraction it feeds in f2 therefore has no effect. Looks like a leftover debug hook.
int fuck;
int f1(){
	int now = 0,n = 0,m = 0;
	rep(i,1,tot) {rep(j,0,SZ(in[i])-1) A[j+i-1] = calc(A[i+j-1],in[i][j]);n = max(n,i+SZ(in[i])-1);}
	rep(i,1,tot) {rep(j,0,SZ(in[i])-1) B[j+tot-i] = calc(B[j+tot-i],in[i][j]);m = max(m,tot-i+SZ(in[i]));}
	NTT::Mul(A,n,B,m);
	rep(i,0,n+m-2) now = calc(now,mul(inv[i+2],A[i])),A[i] = 0;
	return now;
}
int f2(){
	int now = 0,n = 0;
	rep(i,1,tot){
		rep(j,0,SZ(in[i])-1) A[j] = calc(A[j],in[i][j]);
		n = SZ(in[i]); Mul2(A,n);
		rep(i,0,2*n-2) {
			now = del(now,mul(inv[2],mul(inv[i+tot],A[i])));
			now = calc(now,mul(inv[i+1+tot],A[i])),A[i] = 0;
		}
	}
	now = del(now,mul(fuck,inv[2]));
	return now;
}
int f3(){
	int now = 0,n = 0;
	rep(i,1,tot) rep(j,0,SZ(in[i])-1) A[j] = calc(A[j],in[i][j]),n = max(n,SZ(in[i]));
	Mul2(A,n); rep(i,0,2*n-2) now = calc(now,mul(inv[i+tot],A[i]));
	return now;
}
void Solve(){
	int ans = 0,now = 0;
	rep(i,1,200000) inv[i] = ksm(i,p-2);
	rep(i,0,work::tot_mx) ans = calc(ans,mul(inv[i+1],work::ans[i])),now = calc(now,mul(inv[i+1+tot],work::ans[i]));
	if(m == n-1){printf("%d\n",ans);exit(0);}
	file(work::inq); rep(i,1,tot) work::inq[tmp[i]] = true;dep[0] = -1;
	rep(i,1,tot){
		work::inq[tmp[i]] = false;mx_dep[i] = -1; tp = i; get_sz(tmp[i],0); work::again(); work::doit(tmp[i],sz[tmp[i]]);
		rep(i,0,work::tot_mx) now = del(now,mul(work::ans[i],inv[i+1+tot]));
	}
	now = mul(now,inv[2]);now = del(f1(),now);now = del(now,f2());
	now = mul(now,2);ans = calc(ans,now);ans = del(ans,f3());
	printf("%d\n",ans);
}
int main(){
	freopen("eat.in","r",stdin);
	freopen("eat.out","w",stdout);
	gi(n);gi(m);int x,y;
	rep(i,1,m) gi(x),gi(y),insert(x,y); if(m == n-1) goto loop;
	cir::find_circle(); dele(tmp[1],tmp[tot]); dele(tmp[tot],tmp[1]);
	loop:; work::doit(1,n);
	Solve();
    return 0;
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章