2018-2019 ACM-ICPC, Asia Xuzhou Regional Contest G. Rikka with Intersections of Paths(樹上差分+LCA+容斥)

題目鏈接:http://codeforces.com/gym/102012/problem/G

題目大意:有一棵n個結點的樹,現在給出m條樹上的路徑。現在要從這m條路徑中選出k條路徑,使得這k條路徑至少有一個公共交點,問你總共有多少種方案數。

題目思路:(今年徐州現場的銀牌題,我們隊肝到最後也沒能肝出來,錯失了銀牌。。。QAQ,當時忘了一個重要的性質,導致正思路都錯了。還是太菜了)

感慨一下,繼續分析題目。

解決這個題,需要用到一個重要的性質:一個樹上任意兩條路徑如果有交點的話,那麼這些交點中肯定有一個爲兩條路徑中的一條路徑兩端點的lca

有了這個性質的話,我們可以對通過枚舉路徑的交點來求答案。

對於每個節點,我們假設通過這個節點的路徑有M條,以這個點爲LCA且通過這個節點的路徑有N條。

那麼在這個節點對答案的貢獻爲:C_{M}^{K}-C_{M-N}^{K}。這個式子計算出來的是,從通過這個節點的路徑中選出k條路徑,且至少有一條路徑的LCA爲這個節點的方案數,這樣選的話就不會出現重複選的情況了,因爲至少有一條路徑以該節點爲LCA,在以其他點爲交點的時候就不會重複計算了。

而通過某個結點的路徑數我們可以通過樹上差分計算,假設通過u這個節點的路徑爲sum[u]。那麼在更新路徑[u,v]的時候,我們就令sum[u]++,sum[v]++,sum[lca(u,v)]--,sum[fa[lca(u,v)]]--。接着再用dfs一遍即可。

具體實現看代碼:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int>pii;
typedef pair<ll, ll>pll;
const int MX = 3e5 + 5;
const int mod = 1e9 + 7;

int n, m, k;
struct edge {int v, w, nxt;} E[MX << 1];
int head[MX], tot;
int dep[MX], ST[MX][20];
void add_edge(int u, int v) {
    E[tot].v = v; E[tot].nxt = head[u];
    head[u] = tot++;
}
void dfs(int u, int d, int fa) {
    dep[u] = d; ST[u][0] = fa;
    for (int i = head[u]; ~i; i = E[i].nxt) {
        int v = E[i].v;
        if (v == fa) continue;
        dfs(v, d + 1, u);
    }
}
void pre_solve() {
    dfs(1, 0, 1);
    for (int i = 1; i < 20; i++) {
        for (int j = 1; j <= n; j++) {
            ST[j][i] = ST[ST[j][i - 1]][i - 1];
        }
    }
}
int LCA(int u, int v) {
    while (dep[u] != dep[v]) {
        if (dep[u] < dep[v]) swap(u, v);
        int d = dep[u] - dep[v];
        for (int i = 0; i < 20; i++)
            if (d >> i & 1)u = ST[u][i];
    }
    if (u == v) return u;
    for (int i = 19; i >= 0; i--) {
        if (ST[u][i] != ST[v][i]) {
            u = ST[u][i];
            v = ST[v][i];
        }
    }
    return ST[u][0];
}
int sum[MX], lca_sum[MX];
void solve(int u, int fa) {
    for (int i = head[u]; ~i; i = E[i].nxt) {
        int v = E[i].v;
        if (v == fa) continue;
        solve(v, u);
        sum[u] += sum[v];
    }
}

ll f[MX], inv[MX];
ll qpow(ll a, ll b) {
    ll res = 1;
    while (b) {
        if (b & 1) res = (res * a) % mod;
        a = (a * a) % mod;
        b >>= 1;
    }
    return res;
}
void init() {
    f[1] = 1;
    for (int i = 2; i < MX; i++) f[i] = (f[i - 1] * i) % mod;
    inv[MX - 1] = qpow(f[MX - 1], mod - 2);
    for (int i = MX - 2; i >= 1; i--) inv[i] = (inv[i + 1] * (i + 1)) % mod;
}
ll C(int n, int m) {
    if (n < 0 || m < 0 || m > n) return 0;
    if (m == 0 || m == n) return 1;
    return f[n] * inv[n - m] % mod * inv[m] % mod;
}

int main() {
    // FIN;
    init();
    int T; cin >> T;
    while (T--) {
        scanf("%d%d%d", &n, &m, &k);
        for (int i = 1; i <= n; i++) head[i] = -1;
        tot = 0;
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            add_edge(u, v); add_edge(v, u);
        }
        pre_solve();
        for (int i = 1; i <= m; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            int lca = LCA(u, v); lca_sum[lca]++;
            sum[u]++; sum[v]++;
            sum[lca]--;
            if (lca != 1) sum[ST[lca][0]]--;
        }
        solve(1, 0);
        ll ans = 0;
        for (int i = 1; i <= n; i++)
            ans = (ans % mod + ((C(sum[i], k) - C(sum[i] - lca_sum[i], k)) % mod + mod) % mod) % mod;
        printf("%lld\n", ans);
    }
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章