Codeforces Round #527 (Div. 3) F. Tree with Maximum Cost(樹形dp)

題目鏈接:https://codeforces.com/contest/1092/problem/F

題目大意:給出一棵n個節點的樹,每個節點都有一個權值a,樹邊的長度爲1。現在要你從樹中選出一個節點v,

使得\sum_{i=1}^{n}dis(i,v)*a[i]的值最大,dis(i,v)表示節點 i 到節點 v 的長度。

題目思路:考慮樹形dp,我們令節點1爲根節點。

根據所給的式子\sum_{i=1}^{n}dis(i,v)*a[i],我們可以將其轉換爲,從 i 到 v,每條邊的長度都是a[v],這樣就更加方便維護了。

同時定義dis1[u]=\sum_{v=son[u]}dis(v,u)*a[v]dis2[u] = \sum_{v!=son[u]}dis(v,u)*a[v]

現在來考慮如何維護dis1[u]dis2[u]

對於dis1[u],我們可以直接做一遍dfs,用一個輔助數組sum[u]=\sum_{v=son[u]}a[v]來幫忙維護。

這樣的話dis1[u]=\sum_{v=son[u]}dis1[v]+sum[v],這個式子代表了從 u 的子節點 v 過來到達 u 的節點 x 在到達 v 之後還需要走一條邊,走這條邊的花費就是a[x],那麼總的花費就是sum[v]了。

對於dis2[u],我們可以知道,dis2[root]=0。那麼對於當前非根節點的結點 u 來說,就要分以下兩種情況來進行討論:

1、fa[u] = root:在這種情況下,dis2[u]就是等於dis1[root] - dis1[u] - sum[u] + (sum[root] - sum[u]),因爲此時就是根節點中除了u這棵子樹之外所有子樹的結點到u的花費之和了。

2、fa[u]!=root:在這種情況下,

dis2[u] = dis2[fa] + (sum[root] - sum[fa]) + (sum[fa] - sum[u])+ (dis1[fa] - dis1[u] - sum[u]);

通過這幾個狀態轉移方程再進行一遍dfs就可以維護出所有的dis2[u]了。

這樣最終的結果就是max(dis1[i]+dis2[i])(i=1,2,3,4,...,n)

具體實現看代碼:

#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define clr(a) memset(a,0,sizeof(a))
#define _inf(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define fuck(x) cout<<"["<<#x<<" " << (x) << "]"<<"\n"
using namespace std;
typedef long long ll;
typedef pair<int, int>pii;
const int MX = 2e5 + 5;

int n;
int a[MX];
vector<int>E[MX];
ll dis1[MX], dis2[MX];
ll sum[MX];

void dfs(int u, int fa) {
    sum[u] = a[u];
    for (auto v : E[u]) {
        if (v == fa) continue;
        dfs(v, u);
        sum[u] += sum[v];
        dis1[u] += dis1[v] + sum[v];
    }
}
void dfs2(int u, int fa) {
    if (u != 1) {
        if (fa == 1) {
            dis2[u] = dis1[fa] - dis1[u] - sum[u] + (sum[fa] - sum[u]);
        } else {
            dis2[u] = dis2[fa] + (sum[1] - sum[fa]) + (sum[fa] - sum[u]) + (dis1[fa] - dis1[u] - sum[u]);
        }
    }
    for (auto v : E[u]) {
        if (v == fa) continue;
        dfs2(v, u);
    }
}

int main() {
    // FIN;
    ios::sync_with_stdio(false);
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        E[u].pb(v); E[v].pb(u);
    }
    dfs(1, 0);
    dfs2(1, 0);
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        ans = max(ans, (dis1[i] + dis2[i]));
    }
    cout << ans << "\n";
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章