點分治
這裏沒有動態點分治。。
點分治是解決樹上問題的一類算法,很多複雜度能從暴力的降低到.
具體做法是就是求一個樹的重心,樹的重心的性質,其所有的子樹中最大的子樹節點數最少,那麼這個點就是這棵樹的重心,刪去重心後,生成的多棵樹儘可能平衡。
就是說子樹的大小可以的降低,這樣複雜度就降下來了。
如果要達到的複雜度,那麼就需要一個的算法的遍歷樹。我們通過洛谷上面的點分治模板題來了解點分治的一些基本步驟。
P3806
給定一個樹,問距離爲K的點對是否存在。
首先我們分類一下,點對可以分爲這幾類。
1.經過根的點對。
2.不經過根的點對。
關於問題1很好解決直接dfs一次就可以了,然後子樹與子樹組合就可以了。
問題2我們可以通過點分樹,化簡成一個小的子樹的根去解決。
那麼問題就回歸到解決問題1了,這大概就是點分的思想,具體情況可能還要討論一下。
我們開始寫代碼,其中代碼最重要的就是求樹的重心,這個也很簡單這裏不講了。
其實很多點分T的有一種可能就是重心弄錯了。
void findrt(int u, int fa) {
sz[u] = 1, son[u] = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v] || v == fa) continue;
findrt(v, u);
sz[u] += sz[v];
son[u] = max(son[u], sz[v]);
}
son[u] = max(son[u], S - sz[u]);
if (son[u] < son[rt]) rt = u;
}
然後開始分治,分治就是以重心爲根,劃分子樹,解決,然後在找子樹的重心。
void divide(int u) {
vis[u] = pd[0] = 1;
solve(u);
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
son[0] = n, S = sz[v], rt = 0;
findrt(v, u);
divide(rt);
}
}
這裏面有的寫法可能for循環中減去到根的寫法,具體看寫法了和情況了。
下面我們來看看怎麼解決爲根的樹的問題。
void get_dis(int u, int fa) {
rev[++tot] = dis[u];
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
dis[v] = dis[u] + ed[i].w;
get_dis(v, u);
}
}
int solve(int u) {
int c = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++)
for (int k = 1; k <= m; k++)
if (que[k] >= rev[j])
ok[k] |= pd[que[k] - rev[j]];
for (int j = 1; j <= tot; j++)
if (rev[j] <= 10000000) q[++c] = rev[j], pd[rev[j]] = 1;
}
for (int i = 1; i <= c; i++) pd[q[i]] = 0;
}
其中我們可以看到get_dis函數就是求u的每一個子樹的距離,而且記錄下來,因爲都經過了根,所以要子樹與子樹的組合,並且最後要清空,同時注意要記錄距離爲0的情況。
這樣由於m的範圍不打,所以總的時間複雜度 是可行的。
接下來講一講洛谷上面的幾道簡單點分題。
P2634
這道題和上面一道題思路差不多,這不就是solve函數中,需要統計餘數爲0,1,2的和。累加ans
ll solve(int u) {
ll cc = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++) {
ll x = rev[j] % 3ll;
if (x == 0) cc += mp[0];
else cc += mp[3 - x];
}
for (int j = 1; j <= tot; j++)
mp[rev[j] % 3ll]++;
}
mp[0] = mp[1] = mp[2] = 0;
return cc;
}
注意這裏得到的ans還需要乘2加n纔是最終的分子。
P4149
這道題也和上面的題相似。
做法,在solve中需要一個桶記錄到某個距離邊權的最小值,然後和第一道題類似方法處理。
void solve(int u) {
num[0] = dep[u] = 0;
int c = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++)
if (k >= rev[j])
ans = min(ans, num[k - rev[j]] + cc[j]);
for (int j = 1; j <= tot; j++)
if (rev[j] <= k)
q[++c] = rev[j], num[rev[j]] = min(num[rev[j]], cc[j]);
}
for (int i = 1; i <= c; i++) num[q[i]] = inf;
}
CF161D Distance in Tree
這道題,其實哪兒練一練模板還是可以的,也很簡單和上面的差不多,當面 n*m的dp也能過。
P4178
這道題和上面的題基本上一樣,只需要改成統計小於的數字就可以了,至於統計與更新用樹狀數組就可以了。
void solve(int u) {
int c = 0;
add(1, 1);
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
tot = 0;
dis[v] = ed[i].w;
get_dis(v, u);
for (int j = 1; j <= tot; j++)
if (k >= rev[j])
ans += sum(k - rev[j] + 1);
for (int j = 1; j <= tot; j++)
if (k >= rev[j])
q[++c] = rev[j], add(rev[j] + 1, 1);
}
add(1, -1);
for (int i = 1; i <= c; i++) add(q[i] + 1, -1);
P2664
這道題就很難了,想不到,看題解。
大概做法就是ans數組記錄每一個點的答案,然後分成每一個點分樹求對樹中每一個點的貢獻。
一個子樹中對於根的貢獻就是,每一個顏色從他到根的路徑上第一次出現那麼,他對根的貢獻就是他的子樹的大小。
對於其他點。
到u到根的這段路徑上的不同顏色爲num,那麼對於點u的貢獻就是,這個下來可以自己想一想。
然後就是點分樹上除去u所在子樹的其他點的不同顏色,其實就是
因此這樣就可以寫了。但不過你思路明白了可能代碼寫起來還不好寫。
#include "bits/stdc++.h"
using namespace std;
inline int read() {
int x = 0;
bool f = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
if (f) return x;
return 0 - x;
}
#define SZ(x) ((int)x.size())
#define ll long long
const int maxn = 100000 + 10;
const ll inf = 1e18;
struct edge {
int u, v, nxt;
} ed[maxn << 1];
int head[maxn << 1], cnt;
void add_e(int u, int v) {
ed[++cnt] = edge{u, v, head[u]};
head[u] = cnt;
}
int sz[maxn], son[maxn], vis[maxn], col[maxn], mx, rt, n, m, S;
void findrt(int u, int fa) {
sz[u] = 1, son[u] = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v] || v == fa) continue;
findrt(v, u);
sz[u] += sz[v];
son[u] = max(son[u], sz[v]);
}
son[u] = max(son[u], S - sz[u]);
if (son[u] < son[rt]) rt = u;
}
ll ans[maxn], c[maxn], mp[maxn], sum, num, tmp;
void dfs1(int u, int fa) {
mp[col[u]]++, sz[u] = 1;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v] || v == fa) continue;
dfs1(v, u);
sz[u] += sz[v];
}
if (mp[col[u]] == 1) {
sum += sz[u];
c[col[u]] += sz[u];
}
mp[col[u]]--;
}
void dfs(int u, int fa, int f) {
mp[col[u]]++;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
dfs(v, u, f);
}
if (mp[col[u]] == 1) {
sum += sz[u] * f;
c[col[u]] += sz[u] * f;
}
mp[col[u]]--;
}
void dfs2(int u, int fa) {
mp[col[u]]++;
if (mp[col[u]] == 1) {
sum -= c[col[u]];
num++;
}
ans[u] += sum + num * tmp;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
dfs2(v, u);
}
if (mp[col[u]] == 1) {
sum += c[col[u]];
num--;
}
mp[col[u]]--;
}
void clear(int u, int fa) {
c[col[u]] = mp[col[u]] = 0;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (v == fa || vis[v]) continue;
clear(v, u);
}
}
void solve(int u) {
dfs1(u, 0);
ans[u] += sum;
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
mp[col[u]]++, sum -= sz[v], c[col[u]] -= sz[v];
dfs(v, u, -1);
mp[col[u]]--;
tmp = sz[u] - sz[v];
dfs2(v, u);
mp[col[u]]++, sum += sz[v], c[col[u]] += sz[v];
dfs(v, u, 1);
mp[col[u]]--;
}
sum = num = 0;
clear(u, 0);
}
void divide(int u) {
vis[u] = 1;
solve(u);
for (int i = head[u]; i; i = ed[i].nxt) {
int v = ed[i].v;
if (vis[v]) continue;
S = sz[v];
son[rt = 0] = n;
findrt(v, 0);
divide(rt);
}
}
int main() {
n = read();
for (int i = 1; i <= n; i++)
col[i] = read();
for (int i = 1, u, v; i < n; i++) {
u = read(), v = read();
add_e(u, v);
add_e(v, u);
}
S = son[rt = 0] = n;
findrt(1, 0);
divide(rt);
for (int i = 1; i <= n; i++) {
cout << ans[i] << endl;
}
return 0;
}