Codeforces Round #527 (Div. 3) F. Tree with Maximum Cost(树形dp)

题目链接:https://codeforces.com/contest/1092/problem/F

题目大意:给出一棵n个节点的树,每个节点都有一个权值a,树边的长度为1。现在要你从树中选出一个节点v,

使得\sum_{i=1}^{n}dis(i,v)*a[i]的值最大,dis(i,v)表示节点 i 到节点 v 的长度。

题目思路:考虑树形dp,我们令节点1为根节点。

根据所给的式子\sum_{i=1}^{n}dis(i,v)*a[i],我们可以将其转换为,从 i 到 v,每条边的长度都是a[v],这样就更加方便维护了。

同时定义dis1[u]=\sum_{v=son[u]}dis(v,u)*a[v]dis2[u] = \sum_{v!=son[u]}dis(v,u)*a[v]

现在来考虑如何维护dis1[u]dis2[u]

对于dis1[u],我们可以直接做一遍dfs,用一个辅助数组sum[u]=\sum_{v=son[u]}a[v]来帮忙维护。

这样的话dis1[u]=\sum_{v=son[u]}dis1[v]+sum[v],这个式子代表了从 u 的子节点 v 过来到达 u 的节点 x 在到达 v 之后还需要走一条边,走这条边的花费就是a[x],那么总的花费就是sum[v]了。

对于dis2[u],我们可以知道,dis2[root]=0。那么对于当前非根节点的结点 u 来说,就要分以下两种情况来进行讨论:

1、fa[u] = root:在这种情况下,dis2[u]就是等于dis1[root] - dis1[u] - sum[u] + (sum[root] - sum[u]),因为此时就是根节点中除了u这棵子树之外所有子树的结点到u的花费之和了。

2、fa[u]!=root:在这种情况下,

dis2[u] = dis2[fa] + (sum[root] - sum[fa]) + (sum[fa] - sum[u])+ (dis1[fa] - dis1[u] - sum[u]);

通过这几个状态转移方程再进行一遍dfs就可以维护出所有的dis2[u]了。

这样最终的结果就是max(dis1[i]+dis2[i])(i=1,2,3,4,...,n)

具体实现看代码:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define clr(a) memset(a,0,sizeof(a))
#define _inf(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define fuck(x) cout<<"["<<#x<<" " << (x) << "]"<<"\n"
using namespace std;
typedef long long ll;
typedef pair<int, int>pii;
const int MX = 2e5 + 5;

int n;
int a[MX];
vector<int>E[MX];
ll dis1[MX], dis2[MX];
ll sum[MX];

void dfs(int u, int fa) {
    sum[u] = a[u];
    for (auto v : E[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sum[u] += sum[v];
        dis1[u] += dis1[v] + sum[v];
    }
}
void dfs2(int u, int fa) {
    if (u != 1) {
        if (fa == 1) {
            dis2[u] = dis1[fa] - dis1[u] - sum[u] + (sum[fa] - sum[u]);
        } else {
            dis2[u] = dis2[fa] + (sum[1] - sum[fa]) + (sum[fa] - sum[u]) + (dis1[fa] - dis1[u] - sum[u]);
        }
    }
    for (auto v : E[u]) {
        if (v == fa) continue;
        dfs2(v, u);
    }
}

int main() {
    // FIN;
    ios::sync_with_stdio(false);
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        E[u].pb(v); E[v].pb(u);
    }
    dfs(1, 0);
    dfs2(1, 0);
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        ans = max(ans, (dis1[i] + dis2[i]));
    }
    cout << ans << "\n";
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章