[WC2019] 數樹
題目傳送門
分析
最近老是在肝一些神仙生成函數題。。。哎,肝敗嚇瘋。其實luogu題解裏面的那篇已經很詳細了,這篇題解純屬個人整理,建議是到到luogu題解去看。
題目大意:告訴你有倆棵有標號無根樹,如果某兩個節點共用了某條邊,那麼這兩個點的權值必須相同,點權範圍在內,有三個任務,求在給定2,1,0棵樹的情況下構造樹和點權的方案數。
Task0:簡單轉化
如果兩棵樹都給的話,就是把都存在的邊放在圖中,假設有個連通塊,答案就是,因爲是樹,所以答案就是,其中是邊數。
Task1:各種套路
現在少了一棵樹,這個時候我們就要考慮怎麼形式化問題。
不難發現,由於答案只和重合的邊條數有關係,所以最簡單的想法是枚舉重合的邊集。
其中表示與初始樹重合邊集恰好爲的方案數。
看到我加粗了恰好兩個字就知道我要幹什麼了對嗎:-)
套路1:容斥原理
設表示重合的邊集包含的方案數。
根據容斥原理,可以得到:
套路2:交換求和
這個時候帶回原式化簡一波
發現後面那坨僅僅和集合大小有關係,我們枚舉子集的大小,那麼答案就是
於是我們得到了重要結論:
然而到目前爲止,複雜度仍然是指數級的,瓶頸在於,因此我們要繼續形式化。
套路3:矩陣樹定理
考慮如果我把給你,要怎麼求。
考慮模型化問題,實際上就是欽定了若干條邊,要求任意連邊求生成樹個數。
中的一個連通塊,顯然不能再連邊,因此將他們縮成一個點,任意兩個連通塊實際上有連通塊大小乘積種連邊方案。所以假設中的聯通塊大小分別爲,對於任意兩個連通塊,我們連條邊,形成的圖的生成樹個數就是。
但是我們還是無法避免枚舉,所以我們還得繼續挖掘的性質。索性,將矩陣拿出來玩玩。
套路4:手玩行列式
先去掉一行一列,得到
發現其實每行都有一個可以提一個。得到
發現每行的和都是,因此把列加到第列上,得到
這樣的話,每一列除了對角線上的元素都是相同的。考慮將每行減去第一行,就得到了一個上三角了。
手玩了一陣子的行列式之後,發現答案實際上就是,其中是連通塊個數,是連通塊大小。
因爲是樹上的邊集,所以連通塊的個數
進一步帶入化簡可以得到
爲了方便起見,我們設
考慮將和分配進去
設,現在問題轉化成了,將一棵樹劃分成若干個連通塊,每個連通塊的貢獻是乘上連通塊大小,一個劃分的權值是所有連通塊貢獻之積,求所有劃分權值之和。
這個問題顯然可以用一個解決
套路5:生成函數優化Dp
這個操作我是第一次見。。。。
首先考慮上面的,假設表示以爲根的子樹,所在連通塊大小爲(未計入答案)的劃分權值之和。
考慮子樹合併貢獻方程,假設合併了一個子樹。
兩個方程分別對應切和不切。我們終於得到了一個的算法!
接下來就應該優化這個方程。
這個時候考慮方程的生成函數:
方程可以被簡單地寫成
然後樹鏈剖分+NTT就可以做到兩個log
這個時候考慮我們需要的答案是什麼?
答案的形式如此簡單,因此我們嘗試在轉移中僅僅轉移答案。
設
我們得到了一個的優秀樹!經過重重套路,終於解決了
最終的答案是
Task2:模型使用
發現其實中的容斥是可以用滴!
我們同樣令。但是卻不能再採用上一題的方法,因爲這裏的是任意一個合法的森林的邊集。這和上一題一顆樹的子邊集大不相同。因此考慮採用模型轉化。
梳理一下問題:
某個連通塊的權值爲其大小的平方
若干個連通塊組成的圖的權值是各個連通塊的權值積
求個點的所有不同森林的權值和
這是一個經典的模型。將連通塊看成一個集合,那麼就成爲了若干個關於集合大小的自由組合問題。
考慮一個大小爲的集合的指數型生成函數:
和答案大小爲個點的指數型生成函數:
有
原因是相當於把若干個不同大小的集合拼在一起,再消除內部順序的影響。
對應這道題,大小爲的樹的方案數有中方案,每種方案的權值都是
所以
那麼構造
求
最後的答案就是
本題層次分明,形成了題面,解法的統一和思維層次的不斷螺旋上升,但卻有章法可詢,雖然難點衆多,但卻可以層層分析,層層推導,是難得一見的好題!(Call爆它)
代碼
沉迷封裝,無法自拔。
#include<bits/stdc++.h>
const int N = 524288, P = 998244353;
int ri() {
char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
int n, y;
int fix(int x) {return (x >> 31 & P) + x;}
int Pow(int x, int k) {
int r = 1;
for(;k; x = 1LL * x * x % P, k >>= 1)
if(k & 1)
r = 1LL * r * x % P;
return r;
}
int Inv(int x) {return Pow(x, P - 2);}
namespace Solve0 {
std::map<long long, bool> mp;
void Work() {
if(y == 1) return printf("%d\n", Pow(n, n - 2)), void();
int cnt = n;
for(int i = 1;i < n; ++i) {
int u = ri(), v = ri();
if(u > v) std::swap(u, v);
mp[1LL * u * n + v] = true;
}
for(int i = 1;i < n; ++i) {
int u = ri(), v = ri();
if(u > v) std::swap(u, v);
if(mp.count(1LL * u * n + v))
--cnt;
}
printf("%d\n", Pow(y, cnt));
}
}
namespace Solve1 {
int t[N], g[N], k, p, pr[N], to[N], nx[N], tp;
void add(int u, int v) {to[++tp] = v; nx[tp] = pr[u]; pr[u] = tp;}
void adds(int u, int v) {add(u, v); add(v, u);}
void Dp(int u, int fa) {
t[u] = 1; g[u] = k;
for(int i = pr[u]; i; i = nx[i])
if(to[i] != fa) {
Dp(to[i], u); int res = t[to[i]] + g[to[i]];
g[u] = (1LL * g[u] * res + 1LL * t[u] * g[to[i]]) % P;
t[u] = 1LL * t[u] * res % P;
}
}
void Work() {
if(y == 1) return printf("%d\n", Pow(n, (n - 2) % (P - 1))), void();
for(int i = 1;i < n; ++i)
adds(ri(), ri());
p = Inv(y) - 1;
k = 1LL * n * Inv(p) % P;
Dp(1, 0);
printf("%d\n", 1LL * Pow(P + 1 - y, n) * Pow(n, P - 3) % P * g[1] % P);
}
}
namespace Solve2 {
typedef std::vector<int> VI;
int L, InvL, R[N], w[N];
void Pre(int m) {
int x = 0; L = 1;
for(;(L <<= 1) < m;) ++x;
for(int i = 1;i < L; ++i)
R[i] = R[i >> 1] >> 1 | (i & 1) << x;
int wn = Pow(3, (P - 1) / L); w[0] = 1;
for(int i = 1;i < L; ++i)
w[i] = 1LL * w[i - 1] * wn % P;
InvL = Inv(L);
}
void NTT(int *F) {
for(int i = 0;i < L; ++i)
if(R[i] > i)
std::swap(F[i], F[R[i]]);
for(int i = 1, d = L >> 1; i < L; i <<= 1, d >>= 1)
for(int j = 0;j < L; j += i << 1) {
int *l = F + j, *r = F + i + j, *p = w, tp;
for(int k = i; k--; ++l, ++r, p += d)
tp = 1LL * *r * *p % P, *r = (*l - tp) % P, *l = (*l + tp) % P;
}
}
void Fill(const VI &a, int *A, int m) {
m = std::min(m, (int)a.size());
for(int i = 0;i < m; ++i)
A[i] = a[i];
for(int i = m; i < L; ++i)
A[i] = 0;
}
void Fill(int *A, int *B, int m) {
for(int i = 0;i < m; ++i)
B[i] = A[i];
for(int i = m; i < L; ++i)
B[i] = 0;
}
VI operator * (const VI &a, const VI &b) {
const int Lim = 3000;
int asz = a.size(), bsz = b.size(), m = asz + bsz - 1;
static VI c; c.resize(m);
if(1LL * asz * bsz <= Lim) {
for(int i = 0;i < m; ++i)
c[i] = 0;
for(int i = 0;i < asz; ++i)
for(int j = 0;j < bsz; ++j)
c[i + j] = (c[i + j] + 1LL * a[i] * b[j]) % P;
return c;
}
Pre(m); static int A[N], B[N];
Fill(a, A, asz); Fill(b, B, bsz);
NTT(A); NTT(B);
for(int i = 0;i < L; ++i)
A[i] = (1LL * A[i] * B[i]) % P;
NTT(A);
for(int i = 0;i < m; ++i)
c[i] = fix(1LL * A[L - i & L - 1] * InvL % P);
return c;
}
VI Inv(const VI &a, int m) {
static int A[N], B[N], C[N];
for(int i = 0;i < m; ++i)
A[i] = 0;
A[0] = ::Inv(a[0]); int n = 1;
for(;n < m;) {
Pre(n << 2);
Fill(A, B, n);
Fill(a, C, n << 1);
NTT(B); NTT(C);
for(int i = 0;i < L; ++i)
B[i] = 1LL * B[i] * B[i] % P * C[i] % P;
NTT(B);
n <<= 1;
for(int i = 0; i < n; ++i)
A[i] = ((A[i] << 1) - 1LL * B[L - i & L - 1] * InvL) % P;
}
static VI c; c.resize(m);
for(int i = 0;i < m; ++i)
c[i] = fix(A[i]);
return c;
}
VI deri(const VI &a) {
int n = a.size();
if(n == 1)
return VI(1, 0);
static VI c; c.resize(n - 1);
for(int i = 1;i < n; ++i)
c[i - 1] = 1LL * a[i] * i % P;
return c;
}
VI inte(const VI &a) {
int n = a.size();
static VI c; c.resize(n + 1);
for(int i = 1;i <= n; ++i)
c[i] = 1LL * a[i - 1] * ::Inv(i) % P;
c[0] = 0;
return c;
}
VI Ln(const VI &a, int m) {
static VI f;
f = deri(a) * Inv(a, m - 1);
f.resize(m - 1);
return inte(f);
}
VI Exp(const VI &a, int m) {
static VI f, g; f.resize(1); f[0] = 1;
int n = 1, asz = a.size();
for(;n < m;) {
n <<= 1;
g = Ln(f, n);
for(int i = 0;i < n; ++i)
g[i] = i < asz ? fix(-g[i] + a[i]) : fix(-g[i]);
(++g[0]) %= P;
f = f * g;
f.resize(n);
}
f.resize(m);
return f;
}
VI f; int ivf[N], fac[N], p, k;
void Work() {
if(y == 1) return printf("%d\n", Pow(n, (n - 2 << 1) % (P - 1))), void();
fac[0] = 1;
for(int i = 1;i <= n; ++i)
fac[i] = 1LL * fac[i - 1] * i % P;
ivf[n] = ::Inv(fac[n]);
for(int i = n; i; --i)
ivf[i - 1] = 1LL * ivf[i] * i % P;
p = ::Inv(y) - 1;
k = 1LL * n * n % P * ::Inv(p) % P;
f.push_back(0);
for(int i = 1;i <= n; ++i)
f.push_back(1LL * k * Pow(i, i) % P * ivf[i] % P);
f = Exp(f, n + 1);
printf("%d\n", 1LL * Pow(1LL * p * y % P, n) * Pow(n, P - 5) % P * f[n] % P * fac[n] % P);
}
}
int main() {
n = ri(); y = ri(); int op = ri();
if(!op) Solve0::Work();
else if(op == 1) Solve1::Work();
else Solve2::Work();
return 0;
}