點分治學習/模板(5道例題)

感謝b站up大佬:不分解的AgOH

點分治用於樹上的大規模路徑操作、統計。它的靈活性高,適用範圍廣。很多樹形dp也可以用點分治搞。其實,樹分治中,還有一個邊分治,不過沒有點分治常用。
基本步驟:

第一步

找樹的重心。我們重心爲根節點,然後按題目的不同情況統計樹各個點到根節點的路徑信息。

第二步

遞歸至每一個子樹,重複第一步,直到分割至葉子節點。

非常簡潔。
下面看一看具體的實現:
找重心函數:
找到原以v爲根的子樹的重心,並切換根。

void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;//siz數組數組樹子樹大小
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];//mp數組是去掉u節點後,剩餘部分的最大一部分。
    }
    mp[u] = max(mp[u], sum-siz[u]);//不要忘記上子樹
    if (mp[u] < mp[rt]) rt = u;//換根
}

分割函數:
對原樹進行劃分

//------main----
getrt(1, 0);//找整樹的重心
getrt(rt, 0);//爲啥是兩遍?因爲我們需要讓siz數組正確(換根後siz數組就不正確了)
divide(rt);
//--------------
void divide(int u)
{
    vis[u] = 1;//vis[i],表示u節點已被選中
    solve(u);//已u爲根節點,統計路徑信息
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);//找v子樹的根節點
        getrt(rt, 0);
        divide(rt);//遞歸劃分子樹
    }
}

統計信息的函數,隨題目而異,這裏引出5道例題。
例1:CF161D Distance in Tree
最基本的樹分治,求樹上距離爲k的點的對數。
直接上代碼,注意看solve函數:
下面是ac代碼:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans, k;
int tmp[N], siz[N], dis[N], mp[N], jd[N*10];
bool vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt++] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<int> que;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        getdis(v, u);//統計v子樹的所有節點到v的距離
        for (int j = 0; j < cnt; j++)
                if (k >= tmp[j])
                    ans += jd[k-tmp[j]];//jd數組是一個桶數組,jd[i]爲路徑i的數量
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j]]++;//每找一個子樹就壓進去一批
        }
    }
    while(que.size())
    {
        jd[que.front()]--;
        que.pop();//以u爲根的子樹統計完畢,清空。
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y, 1); add(y, x, 1);
    }
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    printf("%d\n", ans);
    return 0;
}

例2:洛谷 P3806 【模板】點分治1
統計距離k的點是否存在,離線後,和上一個題基本一樣。挨個統計,挨個比較每一個詢問。
下面是ac代碼:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt;
int tmp[N], siz[N], dis[N], mp[N], q[105];
bool jd[N*10], ans[105], vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt++] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<int> que;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        getdis(v, u);
        for (int j = 0; j < cnt; j++)
            for (int k = 0; k <m; k++)
                if (q[k] >= tmp[j])
                    ans[k] |= jd[q[k]-tmp[j]];
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j]] = 1;
        }
    }
    while(que.size())
    {
        jd[que.front()] = 0;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x, y, w); add(y, x, w);
    }
    for (int i = 0; i < m; i++)
        scanf("%d", &q[i]);
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    for (int i = 0; i < m; i++)
    {
        if (ans[i]) puts("AYE");
        else puts("NAY");
    }
    return 0;
}

例3:洛谷 P4178 Tree
這個題要求我們統計有多少對點之間的距離小於等於k的。這時我們發現,solve函數中更新ans的複雜度不太理想(因爲要對於一個距離g需要累計從0至k-g),不過好在我們可以用樹狀數組維護桶數組jd達到我們的目的。
下面是ac代碼:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[N*100];
bool vis[N];
void change(int x, int y)
{
    for (;x <= 100*N-2; x += x & -x) jd[x] += y;
}
ll ask(int x)
{
    ll ans = 0;
    for (; x; x -= x &-x) ans += jd[x];
    return ans;
}
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[++cnt] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
   // cout << u << " ::::" << endl;
    queue<int> que;
    que.push(0);
    change(1, 1);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
       // cout << v <<":";
        getdis(v, u);
      //  for (int i = 1; i <= cnt; i++)
      //      cout << tmp[i] << " ";
     //   cout << endl;
        for (int j = 1; j <= cnt; j++)
        {
            if (tmp[j] > k) continue;
            ans += ask(k-tmp[j]+1);
        }
        for (int j = 1; j <= cnt; j++)
        {
            que.push(tmp[j]);
            change(tmp[j]+1, 1);
        }
    }
    while(que.size())
    {
        change(que.front()+1, -1);
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x, y, w); add(y, x, w);
    }
    scanf("%d", &k);
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    printf("%d\n", ans);
    return 0;
}

例4:洛谷 P2634 [國家集訓隊]聰聰可可
也是比較基礎的問題,就是統計所有距離是3的倍數的對數。然後,值得注意的是,點分治沒有考慮兩個點重合的情況,看題目樣例,我們是需要討論的。好在兩個點重合(dis==0)都符合題目要求,我們之間吧ans加上一個n就ok了。
下面是ac代碼:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[5];
bool vis[N];
int gcd(int a, int b)
{
    return b?gcd(b,a%b):a;
}
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[++cnt] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
   // cout << u << " ::::" << endl;
    queue<int> que;
    que.push(0);
    jd[0] = 1;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
       // cout << v <<":";
        getdis(v, u);
      //  for (int i = 1; i <= cnt; i++)
      //      cout << tmp[i] << " ";
     //   cout << endl;
        for (int j = 1; j <= cnt; j++)
        {
            if(tmp[j]%3 == 0) ans += jd[0];
            else if (tmp[j]%3 == 1) ans += jd[2];
            else ans += jd[1];
        }
        for (int j = 1; j <= cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j]%3]++;
        }
    }
    while(que.size())
    {
        jd[que.front()%3]--;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x, y, w); add(y, x, w);
    }
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    ans *= 2;
    ans += n;
    int gg = n *n;
    int d = gcd(ans, gg);
    printf("%d/%d\n", ans/d, gg/d);
    return 0;
}

例5:洛谷 P4149 [IOI2011]Race
比較有意思,求出路徑爲k的點對中,邊數最小的數量。不難,我們在getdis的函數裏同時跑出深度dep。然後在路徑爲k的情況下最小化ans。但是,,但是!這狗題沒說權值範圍。。。。沒法看了1e6的jd數組交了一份re了3個點。隨改成2e7還是re了一個點,在開就mlt了。。。。用了unormap,,t了3個點我擦。。。然後改回jd數組強行hash了一波。。可算是過了。不過這也不是正經解法啊。。於是用unormap吸着氧2.6s險些超時。。。兩份代碼都放一下吧。。可能有正經解法,本憨批不知道。
下面是ac代碼(數組強行hash):

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int mod = N *200;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], jd[N*200], dep[N];
pair<int, int> tmp[N];
bool vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt].first = dep[u];
    tmp[cnt++].second = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        dep[v] = dep[u] + 1;
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<pair<int, int> > que;
    que.push(make_pair(0, 0));
    jd[0] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        dep[v] = 1;
        getdis(v, u);
        for (int j = 0; j < cnt; j++)
                if (k >= tmp[j].second)
                    ans = min(ans, jd[k-tmp[j].second%mod] + tmp[j].first);
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j].second%mod] = min(jd[tmp[j].second%mod], tmp[j].first);
        }
    }
    while(que.size())
    {
        jd[que.front().second%mod] = inf;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x+1, y+1, w); add(y+1, x+1, w);
    }
    memset(jd, inf, sizeof(jd));
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    if (ans != inf)
        printf("%d\n", ans);
    else
        puts("-1");
    return 0;
}

(吸氧map):

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <unordered_map>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], dep[N];
pair<int, int> tmp[N];
unordered_map<int, int> jd;
bool vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt].first = dep[u];
    tmp[cnt++].second = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        dep[v] = dep[u] + 1;
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<pair<int, int> > que;
    que.push(make_pair(0, 0));
    jd[0] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        dep[v] = 1;
        getdis(v, u);
        for (int j = 0; j < cnt; j++)
                if (k >= tmp[j].second)
                {
                    if (jd.find(k-tmp[j].second) == jd.end()) jd[k-tmp[j].second] = inf;
                    ans = min(ans, jd[k-tmp[j].second] + tmp[j].first);
                }
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            if (jd.find(tmp[j].second) == jd.end()) jd[tmp[j].second] = inf;
            jd[tmp[j].second] = min(jd[tmp[j].second], tmp[j].first);
        }
    }
    while(que.size())
    {
        jd[que.front().second] = inf;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x+1, y+1, w); add(y+1, x+1, w);
    }
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    if (ans != inf)
        printf("%d\n", ans);
    else
        puts("-1");
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章