CDOJ 1351(树形DP)

题意&思路:

题意:带权边树上,在起点0出发,初始值为V ,求最多能访问的节点数目?

如果数据小一点,我们可以考虑 dp[u][i] ,表示u节点及其子树花费i最多能够访问的节点数,这个是很好做的,树上揹包嘛。

但是题目多询问,于是考虑 dp[u][i][o] ,表示当前节点u及其子树,访问了 i 个节点,并且是否停留在u点的最小花费。

这样的话,会有以下几种情况:
U 为父节点,v 为子节点
1. 访问了v 子树j个点 并且 在v点
访问了u已处理子树i个点 并且不在u点
最后从v回到 u已处理子树中的某一个点
update(dp[u][i+j][0],dp[u][i][0]+dp[v][j][1]+2d);
2. 访问了v 子树j个点 并且不在v点
访问了u已处理子树i个点 并且在u点
最后回到了v子树中的某一个点
update(dp[u][i+j][0],dp[u][i][1]+dp[v][j][0]+d);
3. 访问了v 子树j个点 并且在v点
访问了u已处理子树i个点 并且在u点
但是回到了v点(也就是停在了v点)
update(dp[u][i+j][0],dp[u][i][1]+dp[v][j][1]+d);
4. 问了v 子树j个点 并且在v点
访问了u已处理子树i个点 并且在u点
最后回到了u点
update(dp[u][i+j][1],dp[u][i][1]+dp[v][j][1]+2d);

代码:

#include <bits/stdc++.h>
#define PB push_back
#define FT first
#define SD second
#define MP make_pair
#define INF 0x3f3f3f3f
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int>  P;
const int N = 505,MOD = 7+1e9;
vector<P> G[N];
int dp[N][N][2];
void init(int n)
{
    for(int i = 0;i < n;i ++)
    {
        dp[i][0][0] = dp[i][0][1] = 0;
        for(int j = 1;j <= n;j ++) dp[i][j][0] = dp[i][j][1] = INF;
    }
}
void update(int& x, int y)
{
    x = min(x, y);
}
int dfs(int u, int fa)
{
    dp[u][1][1] = 0;
    int sum = 1;
    for(auto it : G[u])
    {
        int v = it.FT, d = it.SD;
        if(v == fa) continue;
        int sz = dfs(v, u);
        for(int i = sum;i > 0;i --)
        {
            for(int j = sz;j > 0;j --)
            {
                update(dp[u][i+j][0], dp[u][i][0] + dp[v][j][1] + 2 * d);
                update(dp[u][i+j][0], dp[u][i][1] + dp[v][j][0] +     d);
                update(dp[u][i+j][0], dp[u][i][1] + dp[v][j][1] +     d);
                update(dp[u][i+j][1], dp[u][i][1] + dp[v][j][1] + 2 * d);                
            }
        }
        sum += sz;
    }
    return sum;
}
int main()
{
    int n;
    scanf("%d", &n);
    init(n);
    for(int i = 1;i < n;i ++)
    {
        int u, v, d;
        scanf("%d%d%d", &u, &v, &d);
        G[u].PB({v,d});
        G[v].PB({u,d});
    }
    dfs(0, -1);
    int Q;
    scanf("%d", &Q);
    while(Q --)
    {
        int x, ans = 0;
        scanf("%d", &x);
        for(int i = n;i >= 1;i --)
        {
            if(dp[0][i][0] <= x)
            {
                ans = i;
                break;
            }
        }
        printf("%d\n", max(1, ans));
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章