[SDOI2017]蘋果樹 題解

首先,觀察題意,可以發現在最長鏈下再接一個點,結果一定更優。
也就是說,可以免費選一條最長鏈,之後正常選。
我們枚舉選的最長鏈,然後算出剩下部分的最優解。
有4部分:
1、鏈上每個點都選一個。
2、鏈上剩下的部分。
3、鏈的左面。
4、鏈的右面。

1可以直接計算。
那麼,我們需要先進行樹形揹包,然後再通過某方式將其餘3個合併。
我們知道,在此問題中,合併2個揹包是\(O(k)\)的;
但3個及以上則是\(O(k^2)\)的,無法承受。
所以,我們只能在計算中就把其中兩個合併,這樣就只需合併2個了。
可以發現,3和4是正常的樹形揹包,而2是一個貪心的問題。
但是,我們沒有時間給鏈上的點排序,再貪心選擇。
所以,只能將2轉爲正常的樹形揹包問題。
可以這樣想:選x則要選fx,選x的第二個則要選x的第一個。
那麼,我們可以把大於1的拆點,拆成1和a-1。連上父子關係。
不難發現,這樣我們就只需要3,4兩個合併了。所以,只要算出3,4部分,就能\(O(k)\)出解了。

先考慮如何進行樹形揹包:

樹形揹包有2種實現:dfs合併的和dfs序上dp的。
由於本題每個節點上有多個,所以之前的\(O(nk)\)的分析不適用,複雜度是\(O(nk^2)\),顯然超時。
而且,複雜度的瓶頸合併揹包至少是\(O(k^2)\)的,這個是max卷積,不能優化。
所以,這種方法不行。

但是,由於我們不需要知道每個節點的子節點的選擇信息(如每個點選擇了多少子樹節點),
所以,可以考慮dfs序上dp的算法。見博客

這個算法的狀態數是\(O(nk)\)的,且相當於逐漸添加,沒有合併揹包。
添加的過程是一個多重揹包(由於每個節點上有多個),可以用單調隊列優化。這部分可以做到\(O(nk)\)
而且,我們發現:對於問題3,4(即樹鏈的左右),在dfs序上是一段連續區間。這意味着我們可以直接得出3,4的dp值。

對樹做先序遍歷,可以得到鏈的右面(後綴)。對樹做後遍歷,可以得到鏈的左面(前綴)。

總結下:
首先,拆點。
然後,對樹進行先後序遍歷,並用多重揹包的單調隊列優化算出dp值,\(O(nk)\)
最後,枚舉一個葉子,在\(O(k)\)時間算出結果,總共\(O(nk)\)

此外,對於子節點,在3,4兩部分都會被算到,要注意排除。
注意卡常。

代碼:

#include <stdio.h> 
#define inf 999999999
#define setdp(i, j, x) dp[i * (k + 1) + j] = x
#define getdp(i, j) dp[(i) * (k + 1) + j]
#define getod(i, j) ld[(i) * (k + 1) + j] 
int fr[40010],ne[40010],v[40010],bs = 0,sl[40010],sz[40010],n,k;
void addb(int a, int b) {
    v[bs] = b;
    ne[bs] = fr[a];
    fr[a] = bs++;
}
int xl[40010],si[40010],jl[40010],tm = 0,x1[40010],x2[40010];
void dfs1(int u) {
    x1[u] = tm;
    xl[tm++] = u;
    si[u] = 1;
    for (int i = fr[u]; i != -1; i = ne[i]) {
        jl[v[i]] = jl[u] + sz[v[i]];
        dfs1(v[i]);
        si[u] += si[v[i]];
    }
}
void dfs2(int u) {
    for (int i = fr[u]; i != -1; i = ne[i]) dfs2(v[i]);
    xl[++tm] = u;
    x2[u] = tm;
}
int dl[500010],dz[500010],he = 0,ta = 0,dp[60000010],ld[60000010];
void insert(int i, int x) {
    dz[i] = x;
    while (he < ta && dz[dl[ta - 1]] <= x) ta -= 1;
    dl[ta++] = i;
}
void del(int i) {
    if (he < ta && dl[he] == i) he += 1;
}
int getma() {
    if (he < ta) return dz[dl[he]];
    else return - inf;
}
bool ez[20010],kz[20010];
int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &k);
        bs = 0;
        for (int i = 1; i <= n + n; i++) fr[i] = -1;
        for (int i = 1; i <= n; i++) ez[i] = kz[i] = false;
        for (int i = 1; i <= n; i++) {
            int a;
            scanf("%d%d%d", &a, &sl[i], &sz[i]);
            if (i > 1) addb(a, i);
            ez[a] = true;
        }
        for (int i = 1; i <= n; i++) {
            if (sl[i] > 1) {
                sl[i + n] = sl[i] - 1;
                sz[i + n] = sz[i];
                addb(i, i + n);
                sl[i] = 1;
                kz[i] = true;
            }
        }
        jl[1] = sz[1];
        tm = 0;
        dfs1(1);
        for (int i = tm - 1; i >= 0; i--) {
            he = ta = 0;
            for (int j = 0; j <= k; j++) {
                int u = xl[i],
                ma = getdp(i + si[u], j);
                del(j - sl[u] - 1);
                if (j > 0) insert(j - 1, getdp(i + 1, j - 1) - sz[u] * (j - 1));
                int t = getma() + sz[u] * j;
                if (t > ma) ma = t;
                setdp(i, j, ma);
            }
        }
        for (int i = 0; i <= tm * (k + 1) + k; i++) {
            ld[i] = dp[i];
            dp[i] = 0;
        }
        tm = 0;
        dfs2(1);
        for (int i = 1; i <= tm; i++) {
            he = ta = 0;
            for (int j = 0; j <= k; j++) {
                int u = xl[i],
                ma = getdp(i - si[u], j);
                del(j - sl[u] - 1);
                if (j > 0) insert(j - 1, getdp(i - 1, j - 1) - sz[u] * (j - 1));
                int t = getma() + sz[u] * j;
                if (t > ma) ma = t;
                setdp(i, j, ma);
            }
        }
        int jg = -inf;
        for (int i = 1; i <= n; i++) {
            if (ez[i]) continue;
            int ma = -inf;
            for (int j = 0; j <= k; j++) {
                int t = getod(x1[i] + 1, j);
                if (!kz[i]) t += getdp(x2[i] - 1, k - j);
                else t += getdp(x2[i + n] - 1, k - j);
                if (t > ma) ma = t;
            }
            ma += jl[i];
            if (ma > jg) jg = ma;
        }
        printf("%d\n", jg);
        for (int i = 0; i <= tm * (k + 1) + k; i++) ld[i] = dp[i] = 0;
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章