codeforces981H. K Paths NTT樹形Dp

codeforces981H. K Paths

題目鏈接

分析

題目大意:樹上選kk條路徑,要求選擇之後某條邊只能被經過0,1,k0,1,k次,且不能沒有經過kk次的邊。求方案數。

所有被經過kk次的邊形成的一定是一條樹上的路徑,考慮枚舉路徑的兩個點u,vu,v。考慮uu子樹的端點選取。要麼放在uu上,要麼從uu的兒子的子樹挑一個點。注意一個子樹只能挑一個點。
那麼每個子樹可以挑或者不挑,生成函數之後就是c(x)=(1+szsonux)c(x)=\prod (1+sz_{son_u}x)
假設挑了ii個點在子樹裏面,那麼自然有AkiA_k^i的方案書把這些節點分配給kk個端點。
說白了就是Ansu=Akic(i)Ans_u=\sum A_k^ic(i)
兩邊都可以這麼選,所以總的方案相當於是把兩邊的AnsAns乘起來。
預處理出AnsuAns_u,如果兩個點不是祖先關係,直接乘起來,相當於是12[(Ansu)2Ansu2]\frac{1}{2}[(\sum Ans_u)^2-\sum Ans_u^2]
這個時候還要再扣掉祖先關係的貢獻,也就是Ansugu\sum Ans_u g_u,其中gug_u表示子樹的AnsAns和。
祖孫關係怎麼處理?對於某個子樹方向的子節點vv,貢獻是c(x)1+(nsz[u])x1+sz[v]xc(x)\frac{1+(n-sz[u])x}{1+sz[v]x}
可是每個子節點這個東西的處理是O(d)O(d),其中dd是度數。
但是考慮把相同子樹大小的子節點合併處理,每個節點的子節點的不同的szsz只有O(n)O(\sqrt n)
複雜度就是O(nlogn+nn)O(nlogn+n\sqrt n)

代碼

#include<bits/stdc++.h>
const int N = 4e5 + 10, P = 998244353;
typedef std::vector<int> VI;
int ri() {
	char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
	for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
int A[N], B[N], R[N], w[N], pr[N], to[N << 1], nx[N << 1], tp;
int fac[N], ivf[N], res[N], f[N], g[N], h[N], sz[N], val[N], fa[N], n, k, L, IvL;
VI c, d; long long ans;
void add(int u, int v) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp;}
void adds(int u, int v) {add(u, v); add(v, u);}
int fixd(int x) {return x < 0 ? x + P : x;}
int fixu(int x) {return x >= P ? x - P : x;}
void Inc(int &a, int b) {a = fixu(a + b);}
int Pow(int x, int k) {
	int r = 1;
	for(;k; x = 1LL * x * x % P, k >>= 1)
		if(k & 1)
			r = 1LL * r * x % P;
	return r;
}
int Iv(int x) {return Pow(x, P - 2);}
void Pre(int m) {
	L = 1; int x = 0;
	for(;(L <<= 1) < m;) ++x;
	for(int i = 1;i < L; ++i)
		R[i] = R[i >> 1] >> 1 | (i & 1) << x;
	int wn = Pow(3, (P - 1) / L); w[0] = 1;
	for(int i = 1;i < L; ++i)
		w[i] = 1LL * w[i - 1] * wn % P;
	IvL = Iv(L);
}
void NTT(int *F) {
	for(int i = 1;i < L; ++i)
		if(i < R[i])
			std::swap(F[i], F[R[i]]);
	for(int i = 1, d = L >> 1; i < L; i <<= 1, d >>= 1) 
		for(int j = 0;j < L; j += i << 1) {
			int *l = F + j, *r = F + i + j, *p = w, tp;
			for(int k = i; k--; ++l, ++r, p += d)
				tp = 1LL * *r * *p % P, *r = fixd(*l - tp), *l = fixu(*l + tp);
		}
}
void Get(VI a, int *A) {
	for(int i = 0;i < a.size(); ++i)
		A[i] = a[i];
}
VI operator * (VI a, VI b) {
	VI c; int n = a.size() + b.size() - 1;
	Pre(n);
	Get(a, A); Get(b, B);
	for(int i = a.size(); i < L; ++i)
		A[i] = 0;
	for(int i = b.size(); i < L; ++i)
		B[i] = 0;
	NTT(A); NTT(B);
	for(int i = 0;i < L; ++i)
		A[i] = 1LL * A[i] * B[i] % P;
	NTT(A);
	for(int i = 0;i < n; ++i)
		c.push_back(1LL * A[L - i & L - 1] * IvL % P);
	return c;
}
void operator *= (VI &a, int v) {
	a.push_back(0);  
	for(int i = a.size() - 1;i; --i)
		Inc(a[i], 1LL * a[i - 1] * v % P);
}
void operator /= (VI &a, int v) { 
	static int tp[N]; int n = a.size(), iv = Iv(v); tp[n - 1] = 0;
	for(int i = n - 1; i; --i)
		tp[i - 1] = 1LL * fixd(a[i] - tp[i]) * iv % P;
	a.pop_back(); 
	for(int i = 0;i < n - 1; ++i)
		a[i] = tp[i];
}
VI Solve(int L, int R) {
	if(L == R) return (VI){1, val[L]}; int m = L + R >> 1;
	return Solve(L, m) * Solve(m + 1, R);
}
int PA(int m, int n) {return 1LL * fac[m] * ivf[m - n] % P;}
int Calc(VI a) {
	int r = std::min((int)a.size(), k + 1); long long res = 0;
	for(int i = 0;i < r; ++i)
		res += 1LL * PA(k, i) * a[i] % P;
	return res % P;	
}
void Dfs(int u, int fa) {
	sz[u] = 1; ::fa[u] = fa;
	for(int i = pr[u]; i; i = nx[i])
		if(to[i] != fa) {
			Dfs(to[i], u);
			Inc(g[u], g[to[i]]);
			sz[u] += sz[to[i]];
		} 
	int tp = 0;
	for(int i = pr[u]; i; i = nx[i])
		if(to[i] != fa)
			val[++tp] = sz[to[i]];
	if(!tp)
		f[u] = 1, Inc(g[u], 1);
	else {
		std::sort(val + 1, val + tp + 1);
		c = Solve(1, tp);
		f[u] = Calc(c); Inc(g[u], f[u]);
		tp = std::unique(val + 1, val + tp + 1) - val - 1;
		for(int i = 1;i <= tp; ++i) {
			d = c; d /= val[i]; d *= n - sz[u];
			res[val[i]] = Calc(d);
		}
		for(int i = pr[u]; i; i = nx[i])
			if(to[i] != fa)
				h[to[i]] = res[sz[to[i]]];
	}
}
void pre(int n) {
	fac[0] = 1;
	for(int i = 1;i <= n; ++i)
		fac[i] = 1LL * fac[i - 1] * i % P;
	ivf[n] = Iv(fac[n]);
	for(int i = n; i; --i)
		ivf[i - 1] = 1LL * ivf[i] * i % P;
}
int main() {
	n = ri(); k = ri();
	if(k == 1) return printf("%lld\n", (1LL * n * (n - 1) >> 1) % P), 0; 
	pre(std::max(n, k));
	for(int i = 1;i < n; ++i)
		adds(ri(), ri());
	Dfs(1, 0);
	for(int i = 1;i <= n; ++i)
		ans += f[i];
	ans %= P; ans = ans * ans;
	for(int i = 1;i <= n; ++i)
		ans -= 1LL * f[i] * f[i] % P;
	ans %= P; ans = ans * (P + 1 >> 1);
	for(int u = 1;u <= n; ++u)
		for(int i = pr[u]; i; i = nx[i])
		if(to[i] != fa[u])
			ans += 1LL * (h[to[i]] - f[u]) * g[to[i]] % P;
	printf("%d\n", fixd(ans % P));
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章