hdu 5977 Garden of Eden(點分治枚舉路徑模板)

題意:一棵樹,有 n(n≤50000) 個節點,每個點都有一個顏色,共有 k(k≤10) 種顏色,問有多少條路徑可以遍歷到所有 k 種顏色?(一條路徑交換起點終點就算兩條哦)

思路:這個點分治很好想到,不過如何判斷一條路徑上是否包含所有k種顏色是一個問題。但看到k很小,考慮狀態壓縮,最多有(1<<k)-1種狀態,利用點分治可以算出從根節點到所有子節點的路徑,或運算可記錄經過的路徑,將這些路徑狀態記錄,統計每種狀態有多少。然後就是計算有多少路徑包含k種顏色,枚舉數組裏每一個數x,若x與某個數或運算後得(1<<k)-1,也就是((1<<k)-1)^x這個數有多少個,爲了計算x這條路徑上所有的點,這些點與別的路徑匹配後包含k種顏色,這些點我們可以通過枚舉子集來枚舉,然後再與(1<<k)-1進行異或運算,就可以找到了所有的情況。

複雜度O(nlog^{n}2^k)

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1e5 + 10;
int n, k, h[N], cnt, sz[N], rt, vis[N], mx, sn, a[N];
int d[N], haxi[1 << 10];
ll ans;
struct node {
    int v, nt;
} no[N];
void add(int u, int v) {
    no[cnt] = node{v, h[u]};
    h[u] = cnt++;
}
void getroot(int u, int fa) {
    sz[u] = 1;
    int ma = 0;
    for(int i = h[u]; ~i; i = no[i].nt) {
        int v = no[i].v;
        if(!vis[v] && v != fa) {
            getroot(v, u);
            sz[u] += sz[v];
            ma = max(ma, sz[v]);
        }
    }
    ma = max(ma, sn - sz[u]);
    if(mx > ma)
        mx = ma, rt = u;
}
void getsta(int u, int fa, int sta) {
    d[++d[0]] = sta;
    for(int i = h[u]; ~i; i = no[i].nt) {
        int v = no[i].v;
        if(!vis[v] && v != fa)
            getsta(v, u, sta | (1 << a[v]));
    }
}
ll calc(int u, int sta) {
    ll res = 0;
    d[0] = 0;
    memset(haxi, 0, sizeof haxi);
    getsta(u, 0, sta);
    for(int i = 1; i <= d[0]; i++)
        haxi[d[i]]++;
    for(int i = 1; i <= d[0]; i++) {
        haxi[d[i]]--;
        res += haxi[(1 << k) - 1];
        for(int j = d[i]; j; j = (j - 1) & d[i])//枚舉子集
            res += haxi[((1 << k) - 1) ^ j];
        haxi[d[i]]++;
    }
    return res;
}
void dfs(int u) {
    vis[u] = 1, ans += calc(u, 1 << a[u]);
    for(int i = h[u]; ~i; i = no[i].nt) {
        int v = no[i].v;
        if(!vis[v])
            ans -= calc(v, (1 << a[u]) | (1 << a[v])), sn = sz[v], rt = 0, mx = 1e9, getroot(v, 0), dfs(rt);
    }
}
int main() {
    while(~scanf("%d%d", &n, &k)) {
        memset(h, -1, sizeof h);
        memset(vis, 0, sizeof vis);
        ans = 0,  cnt = 0;
        for(int i = 1; i <= n; i++)
            scanf("%d", &a[i]), a[i]--;
        for(int u, v, i = 1; i < n; i++) {
            scanf("%d%d", &u, &v);
            add(u, v), add(v, u);
        }
        sn = n, mx = 1e9, getroot(1, 0), dfs(rt);
        k == 1 ? printf("%lld\n", (ll)n * (ll)n) : printf("%lld\n", ans);
    }
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章