[SDOI2017]苹果树 题解

首先,观察题意,可以发现在最长链下再接一个点,结果一定更优。
也就是说,可以免费选一条最长链,之后正常选。
我们枚举选的最长链,然后算出剩下部分的最优解。
有4部分:
1、链上每个点都选一个。
2、链上剩下的部分。
3、链的左面。
4、链的右面。

1可以直接计算。
那么,我们需要先进行树形揹包,然后再通过某方式将其余3个合并。
我们知道,在此问题中,合并2个揹包是\(O(k)\)的;
但3个及以上则是\(O(k^2)\)的,无法承受。
所以,我们只能在计算中就把其中两个合并,这样就只需合并2个了。
可以发现,3和4是正常的树形揹包,而2是一个贪心的问题。
但是,我们没有时间给链上的点排序,再贪心选择。
所以,只能将2转为正常的树形揹包问题。
可以这样想:选x则要选fx,选x的第二个则要选x的第一个。
那么,我们可以把大于1的拆点,拆成1和a-1。连上父子关系。
不难发现,这样我们就只需要3,4两个合并了。所以,只要算出3,4部分,就能\(O(k)\)出解了。

先考虑如何进行树形揹包:

树形揹包有2种实现:dfs合并的和dfs序上dp的。
由于本题每个节点上有多个,所以之前的\(O(nk)\)的分析不适用,复杂度是\(O(nk^2)\),显然超时。
而且,复杂度的瓶颈合并揹包至少是\(O(k^2)\)的,这个是max卷积,不能优化。
所以,这种方法不行。

但是,由于我们不需要知道每个节点的子节点的选择信息(如每个点选择了多少子树节点),
所以,可以考虑dfs序上dp的算法。见博客

这个算法的状态数是\(O(nk)\)的,且相当于逐渐添加,没有合并揹包。
添加的过程是一个多重揹包(由于每个节点上有多个),可以用单调队列优化。这部分可以做到\(O(nk)\)
而且,我们发现:对于问题3,4(即树链的左右),在dfs序上是一段连续区间。这意味着我们可以直接得出3,4的dp值。

对树做先序遍历,可以得到链的右面(后缀)。对树做后遍历,可以得到链的左面(前缀)。

总结下:
首先,拆点。
然后,对树进行先后序遍历,并用多重揹包的单调队列优化算出dp值,\(O(nk)\)
最后,枚举一个叶子,在\(O(k)\)时间算出结果,总共\(O(nk)\)

此外,对于子节点,在3,4两部分都会被算到,要注意排除。
注意卡常。

代码:

#include <stdio.h> 
#define inf 999999999
#define setdp(i, j, x) dp[i * (k + 1) + j] = x
#define getdp(i, j) dp[(i) * (k + 1) + j]
#define getod(i, j) ld[(i) * (k + 1) + j] 
int fr[40010],ne[40010],v[40010],bs = 0,sl[40010],sz[40010],n,k;
void addb(int a, int b) {
    v[bs] = b;
    ne[bs] = fr[a];
    fr[a] = bs++;
}
int xl[40010],si[40010],jl[40010],tm = 0,x1[40010],x2[40010];
void dfs1(int u) {
    x1[u] = tm;
    xl[tm++] = u;
    si[u] = 1;
    for (int i = fr[u]; i != -1; i = ne[i]) {
        jl[v[i]] = jl[u] + sz[v[i]];
        dfs1(v[i]);
        si[u] += si[v[i]];
    }
}
void dfs2(int u) {
    for (int i = fr[u]; i != -1; i = ne[i]) dfs2(v[i]);
    xl[++tm] = u;
    x2[u] = tm;
}
int dl[500010],dz[500010],he = 0,ta = 0,dp[60000010],ld[60000010];
void insert(int i, int x) {
    dz[i] = x;
    while (he < ta && dz[dl[ta - 1]] <= x) ta -= 1;
    dl[ta++] = i;
}
void del(int i) {
    if (he < ta && dl[he] == i) he += 1;
}
int getma() {
    if (he < ta) return dz[dl[he]];
    else return - inf;
}
bool ez[20010],kz[20010];
int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &k);
        bs = 0;
        for (int i = 1; i <= n + n; i++) fr[i] = -1;
        for (int i = 1; i <= n; i++) ez[i] = kz[i] = false;
        for (int i = 1; i <= n; i++) {
            int a;
            scanf("%d%d%d", &a, &sl[i], &sz[i]);
            if (i > 1) addb(a, i);
            ez[a] = true;
        }
        for (int i = 1; i <= n; i++) {
            if (sl[i] > 1) {
                sl[i + n] = sl[i] - 1;
                sz[i + n] = sz[i];
                addb(i, i + n);
                sl[i] = 1;
                kz[i] = true;
            }
        }
        jl[1] = sz[1];
        tm = 0;
        dfs1(1);
        for (int i = tm - 1; i >= 0; i--) {
            he = ta = 0;
            for (int j = 0; j <= k; j++) {
                int u = xl[i],
                ma = getdp(i + si[u], j);
                del(j - sl[u] - 1);
                if (j > 0) insert(j - 1, getdp(i + 1, j - 1) - sz[u] * (j - 1));
                int t = getma() + sz[u] * j;
                if (t > ma) ma = t;
                setdp(i, j, ma);
            }
        }
        for (int i = 0; i <= tm * (k + 1) + k; i++) {
            ld[i] = dp[i];
            dp[i] = 0;
        }
        tm = 0;
        dfs2(1);
        for (int i = 1; i <= tm; i++) {
            he = ta = 0;
            for (int j = 0; j <= k; j++) {
                int u = xl[i],
                ma = getdp(i - si[u], j);
                del(j - sl[u] - 1);
                if (j > 0) insert(j - 1, getdp(i - 1, j - 1) - sz[u] * (j - 1));
                int t = getma() + sz[u] * j;
                if (t > ma) ma = t;
                setdp(i, j, ma);
            }
        }
        int jg = -inf;
        for (int i = 1; i <= n; i++) {
            if (ez[i]) continue;
            int ma = -inf;
            for (int j = 0; j <= k; j++) {
                int t = getod(x1[i] + 1, j);
                if (!kz[i]) t += getdp(x2[i] - 1, k - j);
                else t += getdp(x2[i + n] - 1, k - j);
                if (t > ma) ma = t;
            }
            ma += jl[i];
            if (ma > jg) jg = ma;
        }
        printf("%d\n", jg);
        for (int i = 0; i <= tm * (k + 1) + k; i++) ld[i] = dp[i] = 0;
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章