題目鏈接
很容易想到的是:我們可以首先對每個點連出去的邊按照邊的顏色進行排序,這樣就可以保證顏色相同的邊會連續出現;然後進行點分治,再利用線段樹維護每個深度對應的最大權值。我們需要兩棵線段樹,一棵維護與當前邊顏色不同的子樹,一棵維護顏色相同的子樹。另外要保證每次初始化都到位,不然很有可能被卡在70或80分,就比較難受了。
8 2 4 4
-5 3
1 2 2
1 5 1
1 8 2
2 3 1
2 4 1
5 6 2
5 7 1
ans:-4
8 4 1 5
-7 9 6 1
1 2 1
1 3 2
1 4 1
2 5 1
5 6 2
3 7 1
3 8 3
ans:17
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
#include <unordered_map>
#include <unordered_set>
// Lowest set bit of x, e.g. lowbit(12) == 4. The argument is parenthesised so that
// expressions such as lowbit(a + b) expand correctly.
#define lowbit(x) ((x) & (-(x)))
// Mathematical constants as typed constants rather than object-like macros: a macro
// named `e` would silently rewrite every later identifier `e` (e.g. `catch (exception& e)`).
constexpr double pi = 3.141592653589793;
constexpr double e = 2.718281828459045;
// "Infinity" sentinels: INF for int quantities, Big_INF for long long ones.
#define INF 0x3f3f3f3f
#define Big_INF 0x3f3f3f3f3f3f3f3f
// Segment-tree helpers. They expect variables named rt, l, r (and mid, ql, qr for the
// child/query forms) in the scope where they are expanded. Expression-valued macros are
// fully parenthesised so they compose safely with surrounding operators.
#define HalF ((l + r) >> 1)
#define lsn (rt << 1)
#define rsn (rt << 1 | 1)
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
#define MP(a, b) make_pair(a, b)
using namespace std;

using ull = unsigned long long;
using uit = unsigned int;
using ll = long long;

// Capacity for 1-based vertex and colour indices (2e5 plus slack).
const int maxN = 200007;

// Input: N vertices, M colours, and the allowed path length (in edges) range [L, R].
int N, M, L, R;
ll c[maxN];                     // c[col]: weight of colour col
ll ans;                         // best path weight found so far
vector<pair<int, int>> E[maxN]; // E[u]: (colour, neighbour) pairs
bool vis[maxN];                 // vis[u]: u has already been taken as a centroid

// State used by findroot() while searching for a centroid.
int all;       // size of the component currently being split
int root;      // best centroid candidate found so far
int maxx;      // son[] value of that candidate (smaller is better)
int siz[maxN]; // subtree sizes from the latest findroot() traversal
int son[maxN]; // size of the largest piece left if the node were removed
void findroot(int u, int fa)
{
siz[u] = 1; son[u] = 0;
int v;
for(pair<int, int> i : E[u])
{
v = i.second;
if(vis[v] || v == fa) continue;
findroot(v, u);
siz[u] += siz[v];
son[u] = max(son[u], siz[v]);
}
son[u] = max(son[u], all - siz[u]);
if(maxx > son[u]) { maxx = son[u]; root = u; }
}
// Max segment tree over positions [1, N] that supports emptying the whole tree in O(1).
// (Despite its name this is a segment tree, not a Fenwick / binary indexed tree.)
// To empty it, callers set clear[1] = true and tree[1] = -Big_INF; the reset then travels
// lazily down to the children the next time a node is descended into.
struct BIT_Tree
{
    ll tree[maxN << 2];    // tree[rt]: maximum stored in the node's interval
    bool clear[maxN << 2]; // clear[rt]: rt's children still have to be reset to empty

    // Hands a pending reset of rt down to both of its children.
    void pushdown(int rt)
    {
        if(!clear[rt]) return;
        const int lc = rt << 1;
        const int rc = lc | 1;
        tree[lc] = -Big_INF;
        tree[rc] = -Big_INF;
        clear[lc] = true;
        clear[rc] = true;
        clear[rt] = false;
    }

    // Recomputes rt's maximum from its two children.
    void pushup(int rt) { tree[rt] = max(tree[rt << 1], tree[rt << 1 | 1]); }

    // Raises the value at position pos to at least val; node rt covers [l, r].
    void update(int rt, int l, int r, int pos, ll val)
    {
        if(l != r)
        {
            pushdown(rt);
            const int mid = (l + r) >> 1;
            if(pos > mid) update(rt << 1 | 1, mid + 1, r, pos, val);
            else update(rt << 1, l, mid, pos, val);
            pushup(rt);
            return;
        }
        tree[rt] = max(tree[rt], val);
    }

    // Maximum over [ql, qr] within node rt's interval [l, r].
    ll query(int rt, int l, int r, int ql, int qr)
    {
        if(ql <= l && r <= qr) return tree[rt];
        pushdown(rt);
        const int mid = (l + r) >> 1;
        const int lc = rt << 1;
        const int rc = lc | 1;
        if(qr <= mid) return query(lc, l, mid, ql, qr);
        if(ql > mid) return query(rc, mid + 1, r, ql, qr);
        const ll fromLeft = query(lc, l, mid, ql, qr);
        const ll fromRight = query(rc, mid + 1, r, ql, qr);
        return max(fromLeft, fromRight);
    }
} t[2]; // t[0]: "old" (earlier colour groups), t[1]: "same" (current colour group)
// Scratch state filled by dfs() for one subtree of the current centroid.
int deep[maxN]; // deep[x]: number of edges from the centroid to x
int max_deep;   // largest deep[] reached since it was last reset to 0
int Stap[maxN]; // vertices visited by dfs(), held in Stap[1..Stop]
int Stop;       // number of entries currently in Stap
ll dis[maxN];   // dis[x]: colour-run weight of the path centroid -> x
ll sav[maxN];   // sav[d]: best dis[] at depth d within the current colour group
void dfs(int u, int fa, int old_col)
{
Stap[++Stop] = u;
max_deep = max(max_deep, deep[u]);
int v;
for(pair<int, int> i : E[u])
{
v = i.second;
if(vis[v] || v == fa) continue;
dis[v] = dis[u] + (i.first == old_col ? 0 : c[i.first]);
deep[v] = deep[u] + 1;
if(deep[v] <= R) dfs(v, u, i.first);
}
}
// One centroid-decomposition step with u as the centroid.
// Marks u as removed (vis), then visits u's neighbours in adjacency-list order. main()
// sorts every list by (colour, neighbour), so neighbours reached through equal-colour
// edges are processed as one contiguous "colour group". For each neighbour v, dfs()
// collects v's subtree (depth <= R) on Stap, and each collected vertex is paired with
// vertices from subtrees processed earlier:
//   * t[0] keeps, per depth, the best dis[] among earlier groups (first edge of another colour);
//   * t[1] keeps, per depth, the best dis[] inside the current group (same first colour);
//     sav[] mirrors t[1] so the group can be flushed into t[0] once the colour changes.
// Every path whose edge count lies in [L, R] updates the global ans. Finally it recurses
// into each remaining component through findroot().
void Divide(int u)
{
// Remove u and lazily empty both segment trees (clear flag on the root + root value).
vis[u] = true; t[0].clear[1] = true; t[1].clear[1] = true; t[0].tree[1] = -Big_INF; t[1].tree[1] = -Big_INF;
// las: colour of the previous neighbour group (-1 = none yet).
// old_max_deep: deepest depth recorded in sav[] for the current group.
int v, las = -1, old_max_deep = 0;
for(pair<int, int> i : E[u])
{
v = i.second;
if(vis[v]) continue;
// Path u -> v is a single edge: depth 1, weight c[colour].
max_deep = 0; deep[v] = 1; dis[v] = c[i.first];
dfs(v, u, i.first);
// Colour differs from the previous neighbour's: a new colour group begins.
if(las ^ i.first)
{
// Flush the finished group's per-depth maxima from sav[] into t[0] and reset sav[].
// Only depths < R are ever stored, hence the cap at R - 1.
old_max_deep = min(R - 1, old_max_deep);
for(int j=1; j<=old_max_deep; j++)
{
t[0].update(1, 1, N, j, sav[j]);
sav[j] = -Big_INF;
}
// Start tracking the new group and empty t[1].
old_max_deep = max_deep; t[1].clear[1] = true; t[1].tree[1] = -Big_INF;
for(int j=1, id; j<=Stop; j++)
{
id = Stap[j];
// Path from u straight down to id.
if(deep[id] >= L) ans = max(ans, dis[id]);
// Join with an earlier group through u. The first colours differ, so the weights simply add.
// The partner's depth must be in [L - deep, R - deep] for the total length to lie in [L, R].
if(deep[id] < R) ans = max(ans, dis[id] + t[0].query(1, 1, N, L - deep[id], R - deep[id]));
}
// Drain Stap and record this subtree in t[1]/sav[], keeping only per-depth improvements.
// Depth-R vertices are skipped: any partner of depth >= 1 would make the path too long.
while(Stop)
{
int id = Stap[Stop--];
if(deep[id] < R && sav[deep[id]] < dis[id])
{
t[1].update(1, 1, N, deep[id], dis[id]);
sav[deep[id]] = dis[id];
}
}
las = i.first;
}
else
{
// Same colour as the previous neighbour.
for(int j=1, id; j<=Stop; j++)
{
id = Stap[j];
if(deep[id] >= L) ans = max(ans, dis[id]);
// Against t[0] the first colours differ, so the weights add. Against t[1] both halves
// start with colour i.first, so the merged run was counted twice: subtract c once.
if(deep[id] < R) ans = max(ans, dis[id] + max(t[0].query(1, 1, N, L - deep[id], R - deep[id]), t[1].query(1, 1, N, L - deep[id], R - deep[id]) - c[i.first]));
}
old_max_deep = max(old_max_deep, max_deep);
// Same recording step as in the new-group branch.
while(Stop)
{
int id = Stap[Stop--];
if(deep[id] < R && sav[deep[id]] < dis[id])
{
t[1].update(1, 1, N, deep[id], dis[id]);
sav[deep[id]] = dis[id];
}
}
}
}
// Reset sav[] for the last group so it is all -Big_INF again for the next Divide() call.
old_max_deep = min(old_max_deep, R - 1);
for(int i=1; i<=old_max_deep; i++) sav[i] = -Big_INF;
// Recurse into each remaining component. Component sizes come from siz[] as computed by the
// findroot() call that selected u. If v was u's parent in that traversal (siz[v] > siz[u]),
// v's component is everything outside u's subtree; otherwise it is v's own subtree.
int totsiz = all;
for(pair<int, int> i : E[u])
{
v = i.second;
if(vis[v]) continue;
all = siz[v] > siz[u] ? totsiz - siz[u] : siz[v];
maxx = INF;
findroot(v, u);
Divide(root);
}
}
// Resets global state before solving: empties the answer and the dfs stack, and for every
// vertex 1..N clears its adjacency list, unmarks it as a centroid and resets sav[] to -Big_INF.
inline void init()
{
    ans = -Big_INF;
    Stop = 0;
    for(int v = 1; v <= N; ++v)
    {
        E[v].clear();
        vis[v] = false;
        sav[v] = -Big_INF;
    }
}
int main()
{
scanf("%d%d%d%d", &N, &M, &L, &R);
init();
for(int i=1; i<=M; i++) scanf("%lld", &c[i]);
for(int i=1, u, v, w; i<N; i++)
{
scanf("%d%d%d", &u, &v, &w);
E[u].push_back(MP(w, v));
E[v].push_back(MP(w, u));
}
for(int i=1; i<=N; i++) sort(E[i].begin(), E[i].end());
all = N; maxx = INF;
findroot(1, 0);
Divide(root);
printf("%lld\n", ans);
return 0;
}