CF1249F Maximum Weight Subset

CF1249F Maximum Weight Subset

题意

给出一个树,每个点有权值

求最大权值子集,且子集任意一对点的距离大于kk


思路

  • 明显的树上 dpdp

  • dp[i][j]dp[i][j] 表示以 ii 为根的子树,最小深度为 jj 的最大符合条件的子集

  • img

  • 对于一个点,以 11 为根,以 44 为子树为根,显然有两种情况

    • 子集包含 44 那么只需要加上 44 的每个子树深度为 kk 的最大子集即可

    • for (auto it : E[now]) {
          if (it == fa)
              continue;
          dp[now][0] += dp[it][k]; //与it距离为k,与now距离k+1
      }
      
    • 若不取 44 则需要枚举深度为 jj 的点所在的子树,其他子树上的子集与该点距离应大于 kk

    • for (int i = 1; i < n; i++) {
          for (auto it : E[now]) {//枚举距离为j的点
              if (it == fa)
                  continue;
              int cnt = dp[it][i - 1];
              for (auto other : E[now]) {
                  if (other == it || other == fa)
                      continue;
                  cnt += dp[other][max(i - 1, k - i)];//至少距离为i-1
              }
              dp[now][i] = max(dp[now][i], cnt);
          }
      }
      
  • 复杂度

  • 这个算法很容易估算为 O(n4)O(n^4) ,但其实是 O(n3)O(n^3)

  • 因为在算复杂度时,大家往往会忽略 dfsdfs 在枚举子树的过程是 O(n)O(n) 的,因为一棵树是只有n1n-1条边


代码

#include <bits\stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 205;
int a[maxn];
vector<int> E[maxn];
int n, k;
int dp[maxn][maxn];
void dfs(int now, int fa)
{
    dp[now][0] = a[now];
    for (auto it : E[now]) {
        if (it == fa)
            continue;
        dfs(it, now);
    }
    for (auto it : E[now]) {
        if (it == fa)
            continue;
        dp[now][0] += dp[it][k];
    }
    for (int i = 1; i < n; i++) {
        for (auto it : E[now]) {
            if (it == fa)
                continue;
            int cnt = dp[it][i - 1];
            for (auto other : E[now]) {
                if (other == it || other == fa)
                    continue;
                cnt += dp[other][max(i - 1, k - i)];
            }
            dp[now][i] = max(dp[now][i], cnt);
        }
    }
    for (int i = n - 1; i >= 0; i--)
        dp[now][i - 1] = max(dp[now][i - 1], dp[now][i]);
}

int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    int u, v;
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &u, &v);
        E[u].push_back(v);
        E[v].push_back(u);
    }
    dfs(1, 0);
    printf("%d\n", dp[1][0]);
    return 0;
}

back(v);
E[v].push_back(u);
}
dfs(1, 0);
printf("%d\n", dp[1][0]);
return 0;
}


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章