[BJOI2017]樹的難題 點分治 線段樹

題面

[BJOI2017]樹的難題

題解

考慮點分治。
對於每個點,將所有邊按照顏色排序。
那麼只需要考慮如何合併2條鏈。
有2種情況。

  • 合併路徑的接口處2條路徑顏色不同
  • 合併路徑的接口處2條路徑顏色相同

我們分別考慮這2種情況。
維護2棵線段樹,分別表示與當前接口顏色不同和顏色相同。
如果我們遍歷完了一棵子樹,就將這棵子樹的答案加入到顏色相同的線段樹裏面。
如果我們遍歷完了一段顏色,就將第2個線段樹合併到第一個線段樹裏面。
當然更新答案要在上面2個操作之前。
只需要對於當前子樹的每條路徑,在2棵線段樹上分別查詢對應長度區間的答案最大值然後合併即可。
注意從顏色相同線段樹上查詢到的答案合併時需要減一。

// luogu-judger-enable-o2
#include<bits/stdc++.h>
using namespace std;
#define R register int
#define LL long long
#define AC 401000
#define ac 850000
#define inf 9187201950435737472LL

int n, m, rot, lim_l, lim_r, cnt, tinct, top, ss, all, id;
int Head[AC], date[ac], Next[ac], color[ac], tot;
int Size[AC];
LL power[AC], s[AC], f[AC], have[AC], ans = -inf;
bool z[AC];

struct road{
    int x, y, c;
}way[ac];

inline int read()
{
    int x = 0;char c = getchar();bool z_ = false;
    while(c > '9' || c < '0') {if(c == '-') z_ = true; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    if(!z_) return x;
    else return -x;
}

inline void upmin(int &a, int b) {if(b < a) a = b;}
inline void upmax(LL &a, LL b) {if(b > a) a = b;}
inline void add(int f, int w, int S){date[++ tot] = w, Next[tot] = Head[f], Head[f] = tot, color[tot] = S;}
inline bool cmp(road a, road b){return (a.c < b.c);}

struct seg_tree{
    
    LL tree[ac]; int ls[ac], rs[ac], cnt, root;
    
    void init() {cnt = root = 1, tree[1] = tree[0] = -inf, ls[1] = rs[1] = 0;}
    int make() {tree[++ cnt] = -inf, ls[cnt] = rs[cnt] = 0; return cnt;}
    void update(int x) {tree[x] = max(tree[ls[x]], tree[rs[x]]);}
    
    void ins(int &x, int l, int r, int go, LL w)//只有單點修改?
    {
        if(!x) x = make();
        if(l == r){upmax(tree[x], w); return ;}
        int mid = (l + r) >> 1;
        if(go <= mid) ins(ls[x], l, mid, go, w);
        else ins(rs[x], mid + 1, r, go, w);
        update(x);
    }
    
    LL find(int x, int l, int r, int ll, int rr)
    {
        if(!x) return -inf;
        if(l == ll && r == rr) return tree[x];
        int mid = (l + r) >> 1;
        if(rr <= mid) return find(ls[x], l, mid, ll, rr);
        else if(ll > mid) return find(rs[x], mid + 1, r, ll, rr);
        else return max(find(ls[x], l, mid, ll, mid), find(rs[x], mid + 1, r, mid + 1, rr));
    }

}T1, T2;

void merge(){while(id) T1.ins(T1.root, 1, n, have[id], have[id - 1]), id -= 2;}

void getrot(int x, int fa)
{
    f[x] = 0, Size[x] = 1;
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now] || now == fa) continue;
        getrot(now, x);
        upmax(f[x], Size[now]);
        Size[x] += Size[now];
    }
    upmax(f[x], ss - Size[x]);
    if(f[x] < f[rot]) rot = x;
}

void dfs(int x, int fa, int last, int num)//找到當前子樹的每條線段並加入線段樹
{
    //T2.ins(1, 1, n, num, f[x]);
    if(num >= lim_l && num <= lim_r) upmax(ans, f[x]);//不拐彎
    if(num > lim_r) return ; 
    s[++ top] = have[++ id] = f[x], s[++ top] = have[++id] = num;
    int l = max(lim_l - num, 1), r = min(n, lim_r - num);
    if(l <= r) 
    {
        upmax(ans, T2.find(1, 1, n, l, r) + f[x] - power[tinct]);
        upmax(ans, T1.find(1, 1, n, l, r) + f[x]);
    }
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now] || now == fa) continue;
        f[now] = f[x] + ((color[i] == last) ? 0 : power[color[i]]);
        dfs(now, x, color[i], num + 1);
    }
}

void cal(int x)
{
    z[x] = true;
    T1.init(), T2.init();
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now]) continue;
        tinct = color[i], f[now] = power[tinct], dfs(now, x, color[i], 1);
        while(top) T2.ins(T2.root, 1, n, s[top], s[top - 1]), top -= 2;//放到後面再加入防止用到同一棵子樹的點
        if(color[Next[i]] != color[i]) merge(), T2.init();
    }
}

void solve(int x)
{
    //printf("%d\n", x);
    cal(x);
    for(R i = Head[x]; i; i = Next[i])
    {
        int now = date[i];
        if(z[now]) continue;
        rot = 0, f[0] = ss = Size[now];
        getrot(now, 0);
        solve(rot);
    }
}

void pre()
{
    n = read(), m = read(), lim_l = read(), lim_r = read();
    for(R i = 1; i <= m; i ++) power[i] = read();
    for(R i = 1; i < n; i ++) 
    {
        way[++ all].x = read(), way[all].y = read(), way[all].c = read();
        way[all + 1] = way[all], ++all, swap(way[all].x, way[all].y);
    }
    sort(way + 1, way + all + 1, cmp);
    for(R i = 1; i <= all; i ++) add(way[i].x, way[i].y, way[i].c);
}

int main()
{
//  freopen("in.in", "r", stdin);
    pre();
    f[rot] = ss = n;//f[x]表示x的子樹中最重的那棵的重量
    getrot(1, 0);
    solve(rot);
    printf("%lld\n", ans);
//  fclose(stdin);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章