ZOJ-3949 Edge to the Root(樹形dp)

題目鏈接:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3949

題目大意:有一棵以結點1爲根節點且邊權值爲1的樹,現在你可以從結點1向樹中的某一個點x連一條邊。現在要使得樹上除根節點1以外的點到根節點1的距離和最小,問結點1應該和哪個結點連邊。

題目思路:通過畫圖,我們可以知道結點1和結點x連邊之後,只會對結點1到結點x的鏈上的點及其子樹產生影響,同時只會影響到深度爲dep[x]/2 ~ dep[x]的結點及其子樹。

那麼連邊之後,樹上的結點要到達結點1的最短路就會分爲兩類,一類是直接前往結點1,另一類是先走到結點x再走到結點1。我們可以通過兩遍dfs處理出以下的一些信息:

sz[u]:結點u的子樹大小;

dep[u]:結點u的深度;

fa[u][i]:結點u的祖先;

d1[u]:結點u的子樹中所有結點到達結點u的距離之和;

d2[u]:樹上任意一個結點到達結點u的距離之和;

前面4個都是很容易求的,現在講一下d2[u]該如何求。顯然當u=1時,d2[u]=d1[u];當u != 1時,我們就可以藉助u的父親結點來推出d2[u]的值,d2[u] =(d2[fa]-sz[u]*1)+(sz[1]-sz[u])*1,我們已經知道d2[fa]表示樹上任意一點到達結點fa的距離之和,由於現在是要求d2[u],那麼結點u的子樹中的點就不必再走到結點fa了,就減少了sz[u]*1的距離;但是除結點u的子樹以外的點需要走到fa之後,再走到結點u,所以就得再增加(sz[1]-sz[u])*1的距離。這樣就能求出d2[u]的值了。

現在預處理完這些值,接下來就能對答案進行求解了。上面說了,樹上的結點要到達結點1的最短路就會分爲兩類,一類是直接前往結點1,第二類是先走到結點x再走到結點1。且只有深度在dep[x]/2 ~ dep[x]的結點及其子樹會是第二類情況,現在假設向結點x連邊之後會影響到的最上面的點爲par。

那麼第一類的點對答案的貢獻就是res1=d2[1]-d1[par]-sz[par]*dep[par],表示除了par及其子樹外的結點都直接前往結點1的距離之和。

第二類點對答案的貢獻就是res2=d2[x] - (d2[par]-d1[par]+(n-sz[par])*dis)+sz[par],dis表示結點par到結點x的距離。這個式子表示par的子樹內所有的點到達結點x的距離之和再加上通過結點1到結點x的邊到達結點1的距離之和。

那麼結點1向結點x連邊之後,所有結點到達結點1的最小距離之和爲res1+res2

par我們可以通過類似求lca的方法倍增求出來,剩下的部分在前面就預處理好了,只需要O(1)即可求出,所以我們就可以枚舉x求出最終的答案了。

具體實現看代碼:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>pll;
typedef pair<int, int>pii;
typedef vector<int> VI;
const int inf = 0x3f3f3f3f;
const double eps = 1e-6;
const int MX = 2e5 + 5;

int n, _;
struct edge {
	int v, nxt;
} E[MX << 1];
int head[MX], tot;
int sz[MX], dep[MX], fa[MX][22];
ll d1[MX], d2[MX];
void init() {
	clr(d1); clr(d2);
	memset(head, -1, sizeof(head));
	tot = 0;
}
void add_edge(int u, int v) {
	E[tot].v = v; E[tot].nxt = head[u];
	head[u] = tot++;
}
void dfs1(int u, int pre, int d) {
	dep[u] = d;
	sz[u] = 1; fa[u][0] = pre;
	for (int i = 1; i <= 20; i++)
		fa[u][i] = fa[fa[u][i - 1]][i - 1];
	for (int i = head[u]; ~i; i = E[i].nxt) {
		int v = E[i].v;
		if (v == pre) continue;
		dfs1(v, u, d + 1);
		d1[u] += d1[v] + sz[v];
		sz[u] += sz[v];
	}
}
void dfs2(int u, int pre) {
	if (u == 1) d2[u] = d1[u];
	else d2[u] = d2[pre] - sz[u] + (sz[1] - sz[u]);
	for (int i = head[u]; ~i; i = E[i].nxt) {
		int v = E[i].v;
		if (v == pre) continue;
		dfs2(v, u);
	}
}
int Find(int x, int dis) {
	for (int i = 20; i >= 0; i--) {
		if ((dis >> i) & 1) x = fa[x][i];
	}
	return x;
}

int main() {
	//FIN;
	for (scanf("%d", &_); _; _--) {
		scanf("%d", &n);
		init();
		for (int i = 1; i < n; i++) {
			int u, v;
			scanf("%d%d", &u, &v);
			add_edge(u, v); add_edge(v, u);
		}
		dfs1(1, 0, 0); dfs2(1, 0);
		ll ans = d1[1];
		for (int i = 2; i <= n; i++) {
			int dis = (dep[i] - dep[i] / 2 - 1);
			int par = Find(i, dis);
			ll res = d2[par] - d1[par] + 1ll * (n - sz[par]) * dis;
			ll sub_dis = d2[i] - res + sz[par];
			ll all_dis = d2[1] - d1[par] - 1ll * sz[par] * dep[par];
			ans = min(ans, sub_dis + all_dis);
		}
		cout << ans << endl;
	}
	return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章