題目鏈接:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3949
題目大意:有一棵以結點1爲根節點且邊權值爲1的樹,現在你可以從結點1向樹中的某一個點x連一條邊。現在要使得樹上除根節點1以外的點到根節點1的距離和最小,問結點1應該和哪個結點連邊。
題目思路:通過畫圖,我們可以知道結點1和結點x連邊之後,只會對結點1到結點x的鏈上的點及其子樹產生影響,同時只會影響到深度爲dep[x]/2 ~ dep[x]的結點及其子樹。
那麼連邊之後,樹上的結點要到達結點1的最短路就會分爲兩類,一類是直接前往結點1,另一類是先走到結點x再走到結點1。我們可以通過兩遍dfs處理出以下的一些信息:
:結點u的子樹大小;
:結點u的深度;
:結點u的祖先;
:結點u的子樹中所有結點到達結點u的距離之和;
:樹上任意一個結點到達結點u的距離之和;
前面4個都是很容易求的,現在講一下該如何求。顯然當u=1時,;當u != 1時,我們就可以藉助u的父親結點來推出的值,,我們已經知道表示樹上任意一點到達結點fa的距離之和,由於現在是要求,那麼結點u的子樹中的點就不必再走到結點fa了,就減少了的距離;但是除結點u的子樹以外的點需要走到fa之後,再走到結點u,所以就得再增加的距離。這樣就能求出的值了。
現在預處理完這些值,接下來就能對答案進行求解了。上面說了,樹上的結點要到達結點1的最短路就會分爲兩類,一類是直接前往結點1,第二類是先走到結點x再走到結點1。且只有深度在dep[x]/2 ~ dep[x]的結點及其子樹會是第二類情況,現在假設向結點x連邊之後會影響到的最上面的點爲par。
那麼第一類的點對答案的貢獻就是,表示除了par及其子樹外的結點都直接前往結點1的距離之和。
第二類點對答案的貢獻就是,dis表示結點par到結點x的距離。這個式子表示par的子樹內所有的點到達結點x的距離之和再加上通過結點1到結點x的邊到達結點1的距離之和。
那麼結點1向結點x連邊之後,所有結點到達結點1的最小距離之和爲。
par我們可以通過類似求lca的方法倍增求出來,剩下的部分在前面就預處理好了,只需要O(1)即可求出,所以我們就可以枚舉x求出最終的答案了。
具體實現看代碼:
#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>pll;
typedef pair<int, int>pii;
typedef vector<int> VI;
const int inf = 0x3f3f3f3f;
const double eps = 1e-6;
const int MX = 2e5 + 5;
int n, _;
struct edge {
int v, nxt;
} E[MX << 1];
int head[MX], tot;
int sz[MX], dep[MX], fa[MX][22];
ll d1[MX], d2[MX];
void init() {
clr(d1); clr(d2);
memset(head, -1, sizeof(head));
tot = 0;
}
void add_edge(int u, int v) {
E[tot].v = v; E[tot].nxt = head[u];
head[u] = tot++;
}
void dfs1(int u, int pre, int d) {
dep[u] = d;
sz[u] = 1; fa[u][0] = pre;
for (int i = 1; i <= 20; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = head[u]; ~i; i = E[i].nxt) {
int v = E[i].v;
if (v == pre) continue;
dfs1(v, u, d + 1);
d1[u] += d1[v] + sz[v];
sz[u] += sz[v];
}
}
void dfs2(int u, int pre) {
if (u == 1) d2[u] = d1[u];
else d2[u] = d2[pre] - sz[u] + (sz[1] - sz[u]);
for (int i = head[u]; ~i; i = E[i].nxt) {
int v = E[i].v;
if (v == pre) continue;
dfs2(v, u);
}
}
int Find(int x, int dis) {
for (int i = 20; i >= 0; i--) {
if ((dis >> i) & 1) x = fa[x][i];
}
return x;
}
int main() {
//FIN;
for (scanf("%d", &_); _; _--) {
scanf("%d", &n);
init();
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
add_edge(u, v); add_edge(v, u);
}
dfs1(1, 0, 0); dfs2(1, 0);
ll ans = d1[1];
for (int i = 2; i <= n; i++) {
int dis = (dep[i] - dep[i] / 2 - 1);
int par = Find(i, dis);
ll res = d2[par] - d1[par] + 1ll * (n - sz[par]) * dis;
ll sub_dis = d2[i] - res + sz[par];
ll all_dis = d2[1] - d1[par] - 1ll * sz[par] * dep[par];
ans = min(ans, sub_dis + all_dis);
}
cout << ans << endl;
}
return 0;
}