感謝b站up大佬:不分解的AgOH。
點分治用於樹上的大規模路徑操作、統計。它的靈活性高,適用範圍廣。很多樹形dp也可以用點分治搞。其實,樹分治中,還有一個邊分治,不過沒有點分治常用。
基本步驟:
第一步
找樹的重心。我們重心爲根節點,然後按題目的不同情況統計樹各個點到根節點的路徑信息。
第二步
遞歸至每一個子樹,重複第一步,直到分割至葉子節點。
非常簡潔。
下面看一看具體的實現:
找重心函數:
找到原以v爲根的子樹的重心,並切換根。
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;//siz數組數組樹子樹大小
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];//mp數組是去掉u節點後,剩餘部分的最大一部分。
}
mp[u] = max(mp[u], sum-siz[u]);//不要忘記上子樹
if (mp[u] < mp[rt]) rt = u;//換根
}
分割函數:
對原樹進行劃分
//------main----
getrt(1, 0);//找整樹的重心
getrt(rt, 0);//爲啥是兩遍?因爲我們需要讓siz數組正確(換根後siz數組就不正確了)
divide(rt);
//--------------
void divide(int u)
{
vis[u] = 1;//vis[i],表示u節點已被選中
solve(u);//已u爲根節點,統計路徑信息
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);//找v子樹的根節點
getrt(rt, 0);
divide(rt);//遞歸劃分子樹
}
}
統計信息的函數,隨題目而異,這裏引出5道例題。
例1:CF161D Distance in Tree
最基本的樹分治,求樹上距離爲k的點的對數。
直接上代碼,注意看solve函數:
下面是ac代碼:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans, k;
int tmp[N], siz[N], dis[N], mp[N], jd[N*10];
bool vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt++] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
queue<int> que;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
getdis(v, u);//統計v子樹的所有節點到v的距離
for (int j = 0; j < cnt; j++)
if (k >= tmp[j])
ans += jd[k-tmp[j]];//jd數組是一個桶數組,jd[i]爲路徑i的數量
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j]]++;//每找一個子樹就壓進去一批
}
}
while(que.size())
{
jd[que.front()]--;
que.pop();//以u爲根的子樹統計完畢,清空。
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
add(x, y, 1); add(y, x, 1);
}
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
printf("%d\n", ans);
return 0;
}
例2:洛谷 P3806 【模板】點分治1
統計距離k的點是否存在,離線後,和上一個題基本一樣。挨個統計,挨個比較每一個詢問。
下面是ac代碼:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt;
int tmp[N], siz[N], dis[N], mp[N], q[105];
bool jd[N*10], ans[105], vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt++] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
queue<int> que;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
getdis(v, u);
for (int j = 0; j < cnt; j++)
for (int k = 0; k <m; k++)
if (q[k] >= tmp[j])
ans[k] |= jd[q[k]-tmp[j]];
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j]] = 1;
}
}
while(que.size())
{
jd[que.front()] = 0;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x, y, w); add(y, x, w);
}
for (int i = 0; i < m; i++)
scanf("%d", &q[i]);
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
for (int i = 0; i < m; i++)
{
if (ans[i]) puts("AYE");
else puts("NAY");
}
return 0;
}
例3:洛谷 P4178 Tree
這個題要求我們統計有多少對點之間的距離小於等於k的。這時我們發現,solve函數中更新ans的複雜度不太理想(因爲要對於一個距離g需要累計從0至k-g),不過好在我們可以用樹狀數組維護桶數組jd達到我們的目的。
下面是ac代碼:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[N*100];
bool vis[N];
void change(int x, int y)
{
for (;x <= 100*N-2; x += x & -x) jd[x] += y;
}
ll ask(int x)
{
ll ans = 0;
for (; x; x -= x &-x) ans += jd[x];
return ans;
}
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[++cnt] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
// cout << u << " ::::" << endl;
queue<int> que;
que.push(0);
change(1, 1);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
// cout << v <<":";
getdis(v, u);
// for (int i = 1; i <= cnt; i++)
// cout << tmp[i] << " ";
// cout << endl;
for (int j = 1; j <= cnt; j++)
{
if (tmp[j] > k) continue;
ans += ask(k-tmp[j]+1);
}
for (int j = 1; j <= cnt; j++)
{
que.push(tmp[j]);
change(tmp[j]+1, 1);
}
}
while(que.size())
{
change(que.front()+1, -1);
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x, y, w); add(y, x, w);
}
scanf("%d", &k);
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
printf("%d\n", ans);
return 0;
}
例4:洛谷 P2634 [國家集訓隊]聰聰可可
也是比較基礎的問題,就是統計所有距離是3的倍數的對數。然後,值得注意的是,點分治沒有考慮兩個點重合的情況,看題目樣例,我們是需要討論的。好在兩個點重合(dis==0)都符合題目要求,我們之間吧ans加上一個n就ok了。
下面是ac代碼:
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define ll long long
using namespace std;
const int N = 1e5+5;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, k, ans;
int tmp[N], siz[N], dis[N], mp[N];
int jd[5];
bool vis[N];
int gcd(int a, int b)
{
return b?gcd(b,a%b):a;
}
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[++cnt] = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
getdis(v, u);
}
}
void solve(int u)
{
// cout << u << " ::::" << endl;
queue<int> que;
que.push(0);
jd[0] = 1;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
// cout << v <<":";
getdis(v, u);
// for (int i = 1; i <= cnt; i++)
// cout << tmp[i] << " ";
// cout << endl;
for (int j = 1; j <= cnt; j++)
{
if(tmp[j]%3 == 0) ans += jd[0];
else if (tmp[j]%3 == 1) ans += jd[2];
else ans += jd[1];
}
for (int j = 1; j <= cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j]%3]++;
}
}
while(que.size())
{
jd[que.front()%3]--;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x, y, w); add(y, x, w);
}
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
ans *= 2;
ans += n;
int gg = n *n;
int d = gcd(ans, gg);
printf("%d/%d\n", ans/d, gg/d);
return 0;
}
例5:洛谷 P4149 [IOI2011]Race
比較有意思,求出路徑爲k的點對中,邊數最小的數量。不難,我們在getdis的函數裏同時跑出深度dep。然後在路徑爲k的情況下最小化ans。但是,,但是!這狗題沒說權值範圍。。。。沒法看了1e6的jd數組交了一份re了3個點。隨改成2e7還是re了一個點,在開就mlt了。。。。用了unormap,,t了3個點我擦。。。然後改回jd數組強行hash了一波。。可算是過了。不過這也不是正經解法啊。。於是用unormap吸着氧2.6s險些超時。。。兩份代碼都放一下吧。。可能有正經解法,本憨批不知道。
下面是ac代碼(數組強行hash):
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int mod = N *200;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], jd[N*200], dep[N];
pair<int, int> tmp[N];
bool vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt].first = dep[u];
tmp[cnt++].second = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
dep[v] = dep[u] + 1;
getdis(v, u);
}
}
void solve(int u)
{
queue<pair<int, int> > que;
que.push(make_pair(0, 0));
jd[0] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
dep[v] = 1;
getdis(v, u);
for (int j = 0; j < cnt; j++)
if (k >= tmp[j].second)
ans = min(ans, jd[k-tmp[j].second%mod] + tmp[j].first);
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
jd[tmp[j].second%mod] = min(jd[tmp[j].second%mod], tmp[j].first);
}
}
while(que.size())
{
jd[que.front().second%mod] = inf;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x+1, y+1, w); add(y+1, x+1, w);
}
memset(jd, inf, sizeof(jd));
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
if (ans != inf)
printf("%d\n", ans);
else
puts("-1");
return 0;
}
(吸氧map):
#include <iostream>
#include <cstring>
#include <string>
#include <cmath>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <unordered_map>
#include <queue>
#define max(x, y) ((x)>(y)?(x):(y))
#define ll long long
using namespace std;
const int N = 3e5+5;
const int inf = 0x3f3f3f3f;
int n, m;
int ver[N<<1], he[N], ne[N<<1], e[N<<1];
int tot, rt, sum, cnt, ans = inf, k;
int siz[N], dis[N], mp[N], dep[N];
pair<int, int> tmp[N];
unordered_map<int, int> jd;
bool vis[N];
void add(int x, int y, int w)
{
ver[++tot] = y;
ne[tot] = he[x];
e[tot] = w;
he[x] = tot;
}
void getrt(int u, int f)
{
siz[u] = 1; mp[u] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
if (siz[v] > mp[u]) mp[u] = siz[v];
}
mp[u] = max(mp[u], sum-siz[u]);
if (mp[u] < mp[rt]) rt = u;
}
void getdis(int u, int f)
{
tmp[cnt].first = dep[u];
tmp[cnt++].second = dis[u];
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (v == f || vis[v]) continue;
dis[v] = dis[u] + e[i];
dep[v] = dep[u] + 1;
getdis(v, u);
}
}
void solve(int u)
{
queue<pair<int, int> > que;
que.push(make_pair(0, 0));
jd[0] = 0;
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if (vis[v]) continue;
cnt = 0;
dis[v] = e[i];
dep[v] = 1;
getdis(v, u);
for (int j = 0; j < cnt; j++)
if (k >= tmp[j].second)
{
if (jd.find(k-tmp[j].second) == jd.end()) jd[k-tmp[j].second] = inf;
ans = min(ans, jd[k-tmp[j].second] + tmp[j].first);
}
for (int j = 0; j < cnt; j++)
{
que.push(tmp[j]);
if (jd.find(tmp[j].second) == jd.end()) jd[tmp[j].second] = inf;
jd[tmp[j].second] = min(jd[tmp[j].second], tmp[j].first);
}
}
while(que.size())
{
jd[que.front().second] = inf;
que.pop();
}
}
void divide(int u)
{
vis[u] = jd[0] = 1;
solve(u);
for (int i = he[u]; i; i = ne[i])
{
int v = ver[i];
if(vis[v]) continue;
mp[rt=0] = sum = siz[v];
getrt(v, 0);
getrt(rt, 0);
divide(rt);
}
}
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++)
{
int x, y, w;
scanf("%d%d%d", &x, &y, &w);
add(x+1, y+1, w); add(y+1, x+1, w);
}
mp[0] = sum = n;
getrt(1, 0);
getrt(rt, 0);
divide(rt);
if (ans != inf)
printf("%d\n", ans);
else
puts("-1");
return 0;
}