2018-2019 ACM-ICPC, 徐州 G. Rikka with Intersections of Paths(樹上差分+lca+容斥)

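用樹上差分求出每個點被幾條路徑覆蓋(num[i]),並統計以每個點為 LCA 的路徑數(lca_num[i]);然後對每個點 i 容斥計算貢獻:C(num[i], k) - C(num[i] - lca_num[i], k),即從經過 i 的路徑中選 k 條、且至少有一條以 i 為 LCA 的選法數。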

#include<iostream>
#include<cstring>
#include<cstdio>
#include<queue>
#include<cstdlib>
#include<cmath>
#include<stack>
#include<map>
#include<string>
#include<vector>
#include<set>
#include<bitset>
#include<algorithm>
#include<ctime>
#include<tr1/unordered_map>
using namespace std;
#define ll long long
#define lb long double
#define INF 0x3f3f3f3f
#define LINF 0x3f3f3f3f3f3f3f3f
#define ull unsigned long long
#define endl '\n'
#define clr(a, b) memset(a, b, sizeof(a))
// Lowest set bit of x. Fully parenthesized so that e.g. lowbit(a | b) expands correctly.
#define lowbit(x) ((x) & -(x))
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define PB push_back
#define POP pop_back
#define max_ll 9223372036854775807
// rand() reduced modulo x. The argument is parenthesized so random(a + b) works.
#define random(x) (rand() % (x))
//srand((ll)time(0));
//freopen("E://one.txt","r",stdin); // input redirection: read input data from this file
//freopen("E://oneout.txt","w",stdout); // output redirection: write output data to this file
//tr1::unordered_map<int,int>mp;
const double eps = 1e-14;
const double pi = acos(-1);
const int maxn = 3e5 + 10;
const int maxm = (maxn<<5) + 5;
const ll mod = 1e9 + 7;
const int hash_mod = 19260817;
// Binary-lifting levels. Node depths are < maxn < 2^19, so levels 0..18 are
// the most ever used; 20 leaves a little slack.
const int max_log = 20;
int T, n, m, k;
vector<int> g[maxn];                      // tree adjacency lists
int lg[maxn], dep[maxn], fa[maxn][max_log]; // fa[u][i]: 2^i-th ancestor of u
// fac/inv: factorials and inverse factorials mod `mod`;
// num[u]: paths covering u (after the subtree sum); lca_num[u]: paths whose LCA is u.
ll inv[maxn], fac[maxn], num[maxn], lca_num[maxn];
// Roots the tree at u (whose parent is f) and fills dep[] and the binary-lifting
// table fa[][] for every vertex reachable from u.
// dep[x] = dep[parent] + 1, so with f == 0 and dep[0] == 0 the root gets depth 1.
// fa[x][i] is set only for levels with (1 << i) <= dep[x].
// Iterative with an explicit stack. The recursive form needed one call frame per
// tree level and could overflow the call stack on path-shaped trees (n ~ 3e5).
void dfs(int u, int f){
    fa[u][0] = f;
    vector<int> stk(1, u);
    while(!stk.empty()){
        int x = stk.back();
        stk.pop_back();
        int p = fa[x][0];   // set when x was pushed, before x is processed
        dep[x] = dep[p] + 1;
        // All ancestors of x were popped before x, so their fa rows are complete.
        for(int i = 1 ; (1 << i) <= dep[x] ; ++ i)
            fa[x][i] = fa[fa[x][i-1]][i-1];
        for(size_t i = 0 ; i < g[x].size() ; ++ i){
            int y = g[x][i];
            if(y != p){
                fa[y][0] = x;
                stk.push_back(y);
            }
        }
    }
}
// Lowest common ancestor of x and y by binary lifting.
// Reads dep[] and fa[][] (filled by dfs) and lg[], where lg[v] - 1 == floor(log2 v).
int lca(int x, int y){
    int p = x, q = y;            // p will be the deeper endpoint, q the shallower one
    if(dep[q] > dep[p]) swap(p, q);
    // Raise p to q's depth, each jump being the largest power of two not exceeding the gap.
    for(int gap = dep[p] - dep[q] ; gap > 0 ; gap = dep[p] - dep[q])
        p = fa[p][lg[gap] - 1];
    if(p == q) return p;
    // Raise both while their 2^j-th ancestors differ; they stop just below the LCA.
    for(int j = lg[dep[p]] ; j-- > 0 ; ){
        if(fa[p][j] == fa[q][j]) continue;
        p = fa[p][j];
        q = fa[q][j];
    }
    return fa[p][0];
}
// Subtree-sum pass of the tree difference: afterwards num[x] holds the sum of the
// original num values over x's subtree (rooted at u, whose parent is fa).
// num[u] itself is never added into num[fa].
// Iterative, to avoid the call-stack overflow the recursive version risked on deep
// trees: first record a DFS pre-order, then walk it backwards so every vertex is
// complete before it is added into its parent.
void dfss(int u, int fa){
    vector<int> order, parent;          // pre-order of vertices and each one's parent
    vector<int> stk(1, u), stk_par(1, fa);
    while(!stk.empty()){
        int x = stk.back(), p = stk_par.back();
        stk.pop_back();
        stk_par.pop_back();
        order.push_back(x);
        parent.push_back(p);
        for(size_t i = 0 ; i < g[x].size() ; ++ i){
            int y = g[x][i];
            if(y != p){
                stk.push_back(y);
                stk_par.push_back(x);
            }
        }
    }
    // Index 0 is the root u itself; it is deliberately not folded into fa.
    for(size_t i = order.size() ; i-- > 1 ; )
        num[parent[i]] += num[order[i]];
}
// Fast modular power by repeated squaring: returns a^b reduced modulo `mod`.
// Expects b >= 0. main() calls it with b = mod - 2 to get Fermat inverses.
ll ksm(ll a, ll b){
    ll result = 1;
    for(ll sq = a % mod ; b != 0 ; b >>= 1, sq = sq * sq % mod){
        if(b & 1) result = result * sq % mod;
    }
    return result % mod;
}
// Binomial coefficient C(n, m) modulo `mod`, using the fac[] table and the
// inverse-factorial table inv[]. Returns 0 when m is out of [0, n].
ll C(int n, int m){
    if(m < 0 || n < 0 || m > n) return 0;
    // Edge cases answered directly, so inv[0] is never read.
    if(m == 0 || m == n) return 1;
    ll res = fac[n] * inv[m] % mod;
    return res * inv[n - m] % mod;
}
// Reads T test cases. Each case has a tree on n vertices and m query paths.
// For each case, prints the number of k-subsets of the paths whose common
// intersection is non-empty, modulo `mod`.
int main(){
    scanf("%d", &T);
    // lg[i] = floor(log2(i)) + 1 for i >= 1 (lg[0] = 0); lca() subtracts 1 before use.
    for(int i = 1 ; i < maxn ; ++ i){
        lg[i] = lg[i-1] + ((1 << lg[i-1]) == i);
    }
    // fac[i] = i! mod p and inv[i] = (i!)^{-1} mod p.
    // Use a single modular exponentiation for the largest factorial, then the
    // recurrence inv[i-1] = inv[i] * i, instead of one exponentiation per index.
    fac[0] = 1;
    for(int i = 1 ; i < maxn ; ++ i) fac[i] = fac[i-1] * i % mod;
    inv[maxn-1] = ksm(fac[maxn-1], mod-2);
    for(int i = maxn - 1 ; i > 0 ; -- i) inv[i-1] = inv[i] * i % mod;
    int u, v;
    while(T --){
        scanf("%d %d %d", &n, &m, &k);
        // Reset only the rows this test case can touch (vertices 0..n), instead of
        // memset-ing the whole fa table on every test case.
        memset(fa, 0, sizeof(fa[0]) * (n + 1));
        for(int i = 1 ; i <= n ; ++ i){
            g[i].clear();
            dep[i] = num[i] = lca_num[i] = 0;
        }
        for(int i = 1 ; i < n ; ++ i){
            scanf("%d %d", &u, &v);
            g[u].PB(v); g[v].PB(u);
        }
        dfs(1, 0);
        // Vertex difference on the tree: +1 at both endpoints, -1 at the LCA and
        // -1 at the LCA's parent (the root has none). After dfss, num[i] is the
        // number of paths passing through vertex i.
        for(int i = 1 ; i <= m ; ++ i){
            scanf("%d %d", &u, &v);
            num[u] ++; num[v] ++;
            int w = lca(u, v);
            num[w] --; lca_num[w] ++;
            if(w != 1) num[fa[w][0]] --;
        }
        dfss(1, 0);
        // Inclusion-exclusion per vertex i: k-subsets of the paths through i, minus
        // those containing no path whose LCA is i. So each counted subset contains
        // at least one path with LCA i.
        ll ans = 0;
        for(int i = 1 ; i <= n ; ++ i){
            ans = (ans + C(num[i], k) - C(num[i]-lca_num[i], k) + mod) % mod;
        }
        cout << ans << endl;
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章