codeforces981H. K Paths
分析
題目大意:樹上選條路徑,要求選擇之後某條邊只能被經過次,且不能沒有經過次的邊。求方案數。
所有被經過次的邊形成的一定是一條樹上的路徑,考慮枚舉路徑的兩個點。考慮子樹的端點選取。要麼放在上,要麼從的兒子的子樹挑一個點。注意一個子樹只能挑一個點。
那麼每個子樹可以挑或者不挑,生成函數之後就是
假設挑了個點在子樹裏面,那麼自然有的方案書把這些節點分配給個端點。
說白了就是
兩邊都可以這麼選,所以總的方案相當於是把兩邊的乘起來。
預處理出,如果兩個點不是祖先關係,直接乘起來,相當於是
這個時候還要再扣掉祖先關係的貢獻,也就是,其中表示子樹的和。
祖孫關係怎麼處理?對於某個子樹方向的子節點,貢獻是
可是每個子節點這個東西的處理是,其中是度數。
但是考慮把相同子樹大小的子節點合併處理,每個節點的子節點的不同的只有
複雜度就是
代碼
#include<bits/stdc++.h>
const int N = 4e5 + 10, P = 998244353;
typedef std::vector<int> VI;
int ri() {
char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
int A[N], B[N], R[N], w[N], pr[N], to[N << 1], nx[N << 1], tp;
int fac[N], ivf[N], res[N], f[N], g[N], h[N], sz[N], val[N], fa[N], n, k, L, IvL;
VI c, d; long long ans;
void add(int u, int v) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp;}
void adds(int u, int v) {add(u, v); add(v, u);}
int fixd(int x) {return x < 0 ? x + P : x;}
int fixu(int x) {return x >= P ? x - P : x;}
void Inc(int &a, int b) {a = fixu(a + b);}
int Pow(int x, int k) {
int r = 1;
for(;k; x = 1LL * x * x % P, k >>= 1)
if(k & 1)
r = 1LL * r * x % P;
return r;
}
int Iv(int x) {return Pow(x, P - 2);}
void Pre(int m) {
L = 1; int x = 0;
for(;(L <<= 1) < m;) ++x;
for(int i = 1;i < L; ++i)
R[i] = R[i >> 1] >> 1 | (i & 1) << x;
int wn = Pow(3, (P - 1) / L); w[0] = 1;
for(int i = 1;i < L; ++i)
w[i] = 1LL * w[i - 1] * wn % P;
IvL = Iv(L);
}
void NTT(int *F) {
for(int i = 1;i < L; ++i)
if(i < R[i])
std::swap(F[i], F[R[i]]);
for(int i = 1, d = L >> 1; i < L; i <<= 1, d >>= 1)
for(int j = 0;j < L; j += i << 1) {
int *l = F + j, *r = F + i + j, *p = w, tp;
for(int k = i; k--; ++l, ++r, p += d)
tp = 1LL * *r * *p % P, *r = fixd(*l - tp), *l = fixu(*l + tp);
}
}
void Get(VI a, int *A) {
for(int i = 0;i < a.size(); ++i)
A[i] = a[i];
}
VI operator * (VI a, VI b) {
VI c; int n = a.size() + b.size() - 1;
Pre(n);
Get(a, A); Get(b, B);
for(int i = a.size(); i < L; ++i)
A[i] = 0;
for(int i = b.size(); i < L; ++i)
B[i] = 0;
NTT(A); NTT(B);
for(int i = 0;i < L; ++i)
A[i] = 1LL * A[i] * B[i] % P;
NTT(A);
for(int i = 0;i < n; ++i)
c.push_back(1LL * A[L - i & L - 1] * IvL % P);
return c;
}
void operator *= (VI &a, int v) {
a.push_back(0);
for(int i = a.size() - 1;i; --i)
Inc(a[i], 1LL * a[i - 1] * v % P);
}
void operator /= (VI &a, int v) {
static int tp[N]; int n = a.size(), iv = Iv(v); tp[n - 1] = 0;
for(int i = n - 1; i; --i)
tp[i - 1] = 1LL * fixd(a[i] - tp[i]) * iv % P;
a.pop_back();
for(int i = 0;i < n - 1; ++i)
a[i] = tp[i];
}
VI Solve(int L, int R) {
if(L == R) return (VI){1, val[L]}; int m = L + R >> 1;
return Solve(L, m) * Solve(m + 1, R);
}
int PA(int m, int n) {return 1LL * fac[m] * ivf[m - n] % P;}
int Calc(VI a) {
int r = std::min((int)a.size(), k + 1); long long res = 0;
for(int i = 0;i < r; ++i)
res += 1LL * PA(k, i) * a[i] % P;
return res % P;
}
void Dfs(int u, int fa) {
sz[u] = 1; ::fa[u] = fa;
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa) {
Dfs(to[i], u);
Inc(g[u], g[to[i]]);
sz[u] += sz[to[i]];
}
int tp = 0;
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa)
val[++tp] = sz[to[i]];
if(!tp)
f[u] = 1, Inc(g[u], 1);
else {
std::sort(val + 1, val + tp + 1);
c = Solve(1, tp);
f[u] = Calc(c); Inc(g[u], f[u]);
tp = std::unique(val + 1, val + tp + 1) - val - 1;
for(int i = 1;i <= tp; ++i) {
d = c; d /= val[i]; d *= n - sz[u];
res[val[i]] = Calc(d);
}
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa)
h[to[i]] = res[sz[to[i]]];
}
}
void pre(int n) {
fac[0] = 1;
for(int i = 1;i <= n; ++i)
fac[i] = 1LL * fac[i - 1] * i % P;
ivf[n] = Iv(fac[n]);
for(int i = n; i; --i)
ivf[i - 1] = 1LL * ivf[i] * i % P;
}
int main() {
n = ri(); k = ri();
if(k == 1) return printf("%lld\n", (1LL * n * (n - 1) >> 1) % P), 0;
pre(std::max(n, k));
for(int i = 1;i < n; ++i)
adds(ri(), ri());
Dfs(1, 0);
for(int i = 1;i <= n; ++i)
ans += f[i];
ans %= P; ans = ans * ans;
for(int i = 1;i <= n; ++i)
ans -= 1LL * f[i] * f[i] % P;
ans %= P; ans = ans * (P + 1 >> 1);
for(int u = 1;u <= n; ++u)
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa[u])
ans += 1LL * (h[to[i]] - f[u]) * g[to[i]] % P;
printf("%d\n", fixd(ans % P));
return 0;
}