[BJOI2017]樹的難題【點分治+線段樹】

題目鏈接


  很容易想到：先把每個點連出去的邊按照邊的顏色排序，這樣同色的邊（及其子樹）就會連續出現。然後進行點分治，並用線段樹維護「深度 → 最大路徑權值」。這裡需要兩棵線段樹：一棵維護首邊顏色與當前子樹不同的路徑，一棵維護首邊顏色相同的路徑（兩條同色首邊的路徑拼接時，該顏色段只能計一次，所以要減去一次該顏色的權值）。另外，每次分治時的初始化一定要做到位，否則很可能被卡在 70 或 80 分，那就比較難受了。

8 2 4 4
-5 3
1 2 2
1 5 1
1 8 2
2 3 1
2 4 1
5 6 2
5 7 1
ans:-4
8 4 1 5
-7 9 6 1
1 2 1
1 3 2
1 4 1
2 5 1
5 6 2
3 7 1
3 8 3
ans:17
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
// ---- Short-hand macros used by the solution ----
// NOTE(review): `pi` and `e` are object-like macros, so any later bare token
// `pi` or `e` (e.g. a variable named `e`) is silently replaced by a number.
// Neither name, nor lowbit/myself/ull/uit, is used by the code shown here;
// consider deleting them.
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f
#define Big_INF 0x3f3f3f3f3f3f3f3f
// Segment-tree helpers: midpoint, child indices, and the argument packs that
// BIT_Tree::update / BIT_Tree::query pass to their recursive calls.
#define HalF (l + r)>>1
#define lsn rt<<1
#define rsn rt<<1|1
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
#define MP(a, b) make_pair(a, b)
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;
const int maxN = 2e5 + 7;
// N vertices and M colours (read in main). L and R bound the number of edges
// of the paths that Divide() counts.
int N, M, L, R;
// c[col]: weight of colour col. ans: best path weight found so far.
ll c[maxN], ans;
// E[u]: (colour, neighbour) for every edge at u; sorted by colour in main.
vector<pair<int, int>> E[maxN];
// vis[u]: u has already been used as a centroid (it is removed from the tree).
bool vis[maxN];
// Centroid search state: all = size of the component being searched,
// root/maxx = best centroid so far and its largest remaining piece,
// siz[] = subtree sizes from the last findroot pass, son[] = largest piece.
int all, root, maxx, siz[maxN], son[maxN];
// Walks the component containing u (vertices with vis[] set are treated as
// removed) and fills siz[] with subtree sizes rooted at the start of the walk.
// For each vertex it stores in son[] the size of the largest piece left if that
// vertex were removed. The vertex minimising that value (the centroid) is kept
// in root, and its piece size in maxx. `all` must hold the component size
// before the call, and maxx must be reset by the caller.
void findroot(int u, int fa)
{
    siz[u] = 1;
    int heaviest = 0;
    for(const auto &edge : E[u])
    {
        const int to = edge.second;
        if(to == fa || vis[to]) continue;
        findroot(to, u);
        siz[u] += siz[to];
        heaviest = max(heaviest, siz[to]);
    }
    // The part of the component "above" u is also a piece after removing u.
    son[u] = max(heaviest, all - siz[u]);
    if(son[u] < maxx)
    {
        maxx = son[u];
        root = u;
    }
}
struct BIT_Tree
{
    ll tree[maxN << 2];
    bool clear[maxN << 2];
    void pushdown(int rt)
    {
        if(clear[rt])
        {
            clear[lsn] = clear[rsn] = true;
            tree[lsn] = tree[rsn] = -Big_INF;
            clear[rt] = false;
        }
    }
    void pushup(int rt) { tree[rt] = max(tree[lsn], tree[rsn]); }
    void update(int rt, int l, int r, int pos, ll val)
    {
        if(l == r) { tree[rt] = max(tree[rt], val); return; }
        pushdown(rt);
        int mid = HalF;
        if(pos <= mid) update(Lson, pos, val);
        else update(Rson, pos, val);
        pushup(rt);
    }
    ll query(int rt, int l, int r, int ql, int qr)
    {
        if(ql <= l && qr >= r) return tree[rt];
        pushdown(rt);
        int mid = HalF;
        if(qr <= mid) return query(QL);
        else if(ql > mid) return query(QR);
        else return max(query(QL), query(QR));
    }
} t[2]; //old, same
// Per-subtree scratch filled by dfs(): deep[x] = number of edges from the
// current centroid to x, max_deep = largest deep[] reached, and
// Stap[1..Stop] = stack of the visited nodes.
int deep[maxN], max_deep, Stap[maxN], Stop;
// dis[x]: weight of the centroid->x path, where a run of equal-coloured edges
// contributes its colour weight only once (see dfs).
// sav[d]: best dis[] at depth d within the current colour group; it mirrors
// the leaves of t[1], so t[1] is only updated when sav[] improves.
ll dis[maxN], sav[maxN];
void dfs(int u, int fa, int old_col)
{
    Stap[++Stop] = u;
    max_deep = max(max_deep, deep[u]);
    int v;
    for(pair<int, int> i : E[u])
    {
        v = i.second;
        if(vis[v] || v == fa) continue;
        dis[v] = dis[u] + (i.first == old_col ? 0 : c[i.first]);
        deep[v] = deep[u] + 1;
        if(deep[v] <= R) dfs(v, u, i.first);
    }
}
void Divide(int u)
{
    vis[u] = true; t[0].clear[1] = true; t[1].clear[1] = true; t[0].tree[1] = -Big_INF; t[1].tree[1] = -Big_INF;
    int v, las = -1, old_max_deep = 0;
    for(pair<int, int> i : E[u])
    {
        v = i.second;
        if(vis[v]) continue;
        max_deep = 0; deep[v] = 1; dis[v] = c[i.first];
        dfs(v, u, i.first);
        if(las ^ i.first)
        {
            old_max_deep = min(R - 1, old_max_deep);
            for(int j=1; j<=old_max_deep; j++)
            {
                t[0].update(1, 1, N, j, sav[j]);
                sav[j] = -Big_INF;
            }
            old_max_deep = max_deep; t[1].clear[1] = true; t[1].tree[1] = -Big_INF;
            for(int j=1, id; j<=Stop; j++)
            {
                id = Stap[j];
                if(deep[id] >= L) ans = max(ans, dis[id]);
                if(deep[id] < R) ans = max(ans, dis[id] + t[0].query(1, 1, N, L - deep[id], R - deep[id]));
            }
            while(Stop)
            {
                int id = Stap[Stop--];
                if(deep[id] < R && sav[deep[id]] < dis[id])
                {
                    t[1].update(1, 1, N, deep[id], dis[id]);
                    sav[deep[id]] = dis[id];
                }
            }
            las = i.first;
        }
        else
        {
            for(int j=1, id; j<=Stop; j++)
            {
                id = Stap[j];
                if(deep[id] >= L) ans = max(ans, dis[id]);
                if(deep[id] < R) ans = max(ans, dis[id] + max(t[0].query(1, 1, N, L - deep[id], R - deep[id]), t[1].query(1, 1, N, L - deep[id], R - deep[id]) - c[i.first]));
            }
            old_max_deep = max(old_max_deep, max_deep);
            while(Stop)
            {
                int id = Stap[Stop--];
                if(deep[id] < R && sav[deep[id]] < dis[id])
                {
                    t[1].update(1, 1, N, deep[id], dis[id]);
                    sav[deep[id]] = dis[id];
                }
            }
        }
    }
    old_max_deep = min(old_max_deep, R - 1);
    for(int i=1; i<=old_max_deep; i++) sav[i] = -Big_INF;
    int totsiz = all;
    for(pair<int, int> i : E[u])
    {
        v = i.second;
        if(vis[v]) continue;
        all = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
        maxx = INF;
        findroot(v, u);
        Divide(root);
    }
}
// Resets per-run state for vertices 1..N: empties the adjacency lists, marks
// every vertex as not yet used as a centroid, fills sav[] with -Big_INF,
// empties the dfs stack and sets ans to -Big_INF.
inline void init()
{
    ans = -Big_INF;
    Stop = 0;
    for(int node = 1; node <= N; ++node)
    {
        E[node].clear();
        vis[node] = false;
        sav[node] = -Big_INF;
    }
}
// Reads N, M, L, R, the M colour weights and the N-1 coloured edges. It then
// runs the centroid decomposition from the centroid of the whole tree and
// prints the best path weight.
int main()
{
    scanf("%d%d%d%d", &N, &M, &L, &R);
    init();
    for(int col = 1; col <= M; ++col) scanf("%lld", c + col);
    for(int edge = 1; edge < N; ++edge)
    {
        int a, b, col;
        scanf("%d%d%d", &a, &b, &col);
        E[a].emplace_back(col, b);
        E[b].emplace_back(col, a);
    }
    // Pairs are (colour, neighbour), so sorting groups same-coloured edges;
    // Divide() relies on that grouping.
    for(int node = 1; node <= N; ++node) sort(E[node].begin(), E[node].end());
    all = N;
    maxx = INF;
    findroot(1, 0);
    Divide(root);
    printf("%lld\n", ans);
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章