点分治学习/模板(5道例题)

感谢b站up大佬:不分解的AgOH

点分治用于树上的大规模路径操作、统计。它的灵活性高,适用范围广。很多树形dp也可以用点分治搞。其实,树分治中,还有一个边分治,不过没有点分治常用。
基本步骤:

第一步

找树的重心。我们重心为根节点,然后按题目的不同情况统计树各个点到根节点的路径信息。

第二步

递归至每一个子树,重复第一步,直到分割至叶子节点。

非常简洁。
下面看一看具体的实现:
找重心函数:
找到原以v为根的子树的重心,并切换根。

void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;//siz数组数组树子树大小
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];//mp数组是去掉u节点后,剩余部分的最大一部分。
    }
    mp[u] = max(mp[u], sum-siz[u]);//不要忘记上子树
    if (mp[u] < mp[rt]) rt = u;//换根
}

分割函数:
对原树进行划分

//------main----
getrt(1, 0);//找整树的重心
getrt(rt, 0);//为啥是两遍?因为我们需要让siz数组正确(换根后siz数组就不正确了)
divide(rt);
//--------------
void divide(int u)
{
    vis[u] = 1;//vis[i],表示u节点已被选中
    solve(u);//已u为根节点,统计路径信息
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);//找v子树的根节点
        getrt(rt, 0);
        divide(rt);//递归划分子树
    }
}

统计信息的函数,随题目而异,这里引出5道例题。
例1:CF161D Distance in Tree
最基本的树分治,求树上距离为k的点的对数。
直接上代码,注意看solve函数:
下面是ac代码:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans, k;
int tmp[N], siz[N], dis[N], mp[N], jd[N*10];
bool vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt++] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<int> que;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        getdis(v, u);//统计v子树的所有节点到v的距离
        for (int j = 0; j < cnt; j++)
                if (k >= tmp[j])
                    ans += jd[k-tmp[j]];//jd数组是一个桶数组,jd[i]为路径i的数量
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j]]++;//每找一个子树就压进去一批
        }
    }
    while(que.size())
    {
        jd[que.front()]--;
        que.pop();//以u为根的子树统计完毕,清空。
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y, 1); add(y, x, 1);
    }
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    printf("%d\n", ans);
    return 0;
}

例2:洛谷 P3806 【模板】点分治1
统计距离k的点是否存在,离线后,和上一个题基本一样。挨个统计,挨个比较每一个询问。
下面是ac代码:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt;
int tmp[N], siz[N], dis[N], mp[N], q[105];
bool jd[N*10], ans[105], vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt++] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<int> que;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        getdis(v, u);
        for (int j = 0; j < cnt; j++)
            for (int k = 0; k <m; k++)
                if (q[k] >= tmp[j])
                    ans[k] |= jd[q[k]-tmp[j]];
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j]] = 1;
        }
    }
    while(que.size())
    {
        jd[que.front()] = 0;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x, y, w); add(y, x, w);
    }
    for (int i = 0; i < m; i++)
        scanf("%d", &q[i]);
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    for (int i = 0; i < m; i++)
    {
        if (ans[i]) puts("AYE");
        else puts("NAY");
    }
    return 0;
}

例3:洛谷 P4178 Tree
这个题要求我们统计有多少对点之间的距离小于等于k的。这时我们发现,solve函数中更新ans的复杂度不太理想(因为要对于一个距离g需要累计从0至k-g),不过好在我们可以用树状数组维护桶数组jd达到我们的目的。
下面是ac代码:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[N*100];
bool vis[N];
void change(int x, int y)
{
    for (;x <= 100*N-2; x += x & -x) jd[x] += y;
}
ll ask(int x)
{
    ll ans = 0;
    for (; x; x -= x &-x) ans += jd[x];
    return ans;
}
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[++cnt] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
   // cout << u << " ::::" << endl;
    queue<int> que;
    que.push(0);
    change(1, 1);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
       // cout << v <<":";
        getdis(v, u);
      //  for (int i = 1; i <= cnt; i++)
      //      cout << tmp[i] << " ";
     //   cout << endl;
        for (int j = 1; j <= cnt; j++)
        {
            if (tmp[j] > k) continue;
            ans += ask(k-tmp[j]+1);
        }
        for (int j = 1; j <= cnt; j++)
        {
            que.push(tmp[j]);
            change(tmp[j]+1, 1);
        }
    }
    while(que.size())
    {
        change(que.front()+1, -1);
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x, y, w); add(y, x, w);
    }
    scanf("%d", &k);
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    printf("%d\n", ans);
    return 0;
}

例4:洛谷 P2634 [国家集训队]聪聪可可
也是比较基础的问题,就是统计所有距离是3的倍数的对数。然后,值得注意的是,点分治没有考虑两个点重合的情况,看题目样例,我们是需要讨论的。好在两个点重合(dis==0)都符合题目要求,我们之间吧ans加上一个n就ok了。
下面是ac代码:

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[5];
bool vis[N];
int gcd(int a, int b)
{
    return b?gcd(b,a%b):a;
}
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[++cnt] = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        getdis(v, u);
    }
}
void solve(int u)
{
   // cout << u << " ::::" << endl;
    queue<int> que;
    que.push(0);
    jd[0] = 1;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
       // cout << v <<":";
        getdis(v, u);
      //  for (int i = 1; i <= cnt; i++)
      //      cout << tmp[i] << " ";
     //   cout << endl;
        for (int j = 1; j <= cnt; j++)
        {
            if(tmp[j]%3 == 0) ans += jd[0];
            else if (tmp[j]%3 == 1) ans += jd[2];
            else ans += jd[1];
        }
        for (int j = 1; j <= cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j]%3]++;
        }
    }
    while(que.size())
    {
        jd[que.front()%3]--;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x, y, w); add(y, x, w);
    }
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    ans *= 2;
    ans += n;
    int gg = n *n;
    int d = gcd(ans, gg);
    printf("%d/%d\n", ans/d, gg/d);
    return 0;
}

例5:洛谷 P4149 [IOI2011]Race
比较有意思,求出路径为k的点对中,边数最小的数量。不难,我们在getdis的函数里同时跑出深度dep。然后在路径为k的情况下最小化ans。但是,,但是!这狗题没说权值范围。。。。没法看了1e6的jd数组交了一份re了3个点。随改成2e7还是re了一个点,在开就mlt了。。。。用了unormap,,t了3个点我擦。。。然后改回jd数组强行hash了一波。。可算是过了。不过这也不是正经解法啊。。于是用unormap吸着氧2.6s险些超时。。。两份代码都放一下吧。。可能有正经解法,本憨批不知道。
下面是ac代码(数组强行hash):

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int mod = N *200;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], jd[N*200], dep[N];
pair<int, int> tmp[N];
bool vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt].first = dep[u];
    tmp[cnt++].second = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        dep[v] = dep[u] + 1;
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<pair<int, int> > que;
    que.push(make_pair(0, 0));
    jd[0] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        dep[v] = 1;
        getdis(v, u);
        for (int j = 0; j < cnt; j++)
                if (k >= tmp[j].second)
                    ans = min(ans, jd[k-tmp[j].second%mod] + tmp[j].first);
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            jd[tmp[j].second%mod] = min(jd[tmp[j].second%mod], tmp[j].first);
        }
    }
    while(que.size())
    {
        jd[que.front().second%mod] = inf;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x+1, y+1, w); add(y+1, x+1, w);
    }
    memset(jd, inf, sizeof(jd));
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    if (ans != inf)
        printf("%d\n", ans);
    else
        puts("-1");
    return 0;
}

(吸氧map):

#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <unordered_map>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], dep[N];
pair<int, int> tmp[N];
unordered_map<int, int> jd;
bool vis[N];
void add(int x, int y, int w)
{
    ver[++tot] = y;
    ne[tot] = he[x];
    e[tot] = w;
    he[x] = tot;
}
void getrt(int u, int f)
{
    siz[u] = 1; mp[u] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if (siz[v] > mp[u]) mp[u] = siz[v];
    }
    mp[u] = max(mp[u], sum-siz[u]);
    if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
    tmp[cnt].first = dep[u];
    tmp[cnt++].second = dis[u];
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == f || vis[v]) continue;
        dis[v] = dis[u] + e[i];
        dep[v] = dep[u] + 1;
        getdis(v, u);
    }
}
void solve(int u)
{
    queue<pair<int, int> > que;
    que.push(make_pair(0, 0));
    jd[0] = 0;
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (vis[v]) continue;
        cnt = 0;
        dis[v] = e[i];
        dep[v] = 1;
        getdis(v, u);
        for (int j = 0; j < cnt; j++)
                if (k >= tmp[j].second)
                {
                    if (jd.find(k-tmp[j].second) == jd.end()) jd[k-tmp[j].second] = inf;
                    ans = min(ans, jd[k-tmp[j].second] + tmp[j].first);
                }
        for (int j = 0; j < cnt; j++)
        {
            que.push(tmp[j]);
            if (jd.find(tmp[j].second) == jd.end()) jd[tmp[j].second] = inf;
            jd[tmp[j].second] = min(jd[tmp[j].second], tmp[j].first);
        }
    }
    while(que.size())
    {
        jd[que.front().second] = inf;
        que.pop();
    }
}
void divide(int u)
{
    vis[u] = jd[0] = 1;
    solve(u);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if(vis[v]) continue;
        mp[rt=0] = sum = siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++)
    {
        int x, y, w;
        scanf("%d%d%d", &x, &y, &w);
        add(x+1, y+1, w); add(y+1, x+1, w);
    }
    mp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    if (ans != inf)
        printf("%d\n", ans);
    else
        puts("-1");
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章