【題目鏈接】
【算法】
樹鏈剖分
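對於線段樹的每個節點,記錄這段區間的最小值(Min)、最小值的個數(sum)以及值爲 0 的個數(cnt);此外,還要維護兩個懶惰標記:區間賦值標記 taga(-1 表示沒有標記)和區間加法標記 tagb。每條邊的權值存放在其較深端點的 dfn 位置上,因此根節點所在的位置恆爲 0,統計答案時要減去 1。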
【代碼】
本題細節很多,寫程序時要認真嚴謹!
#include<bits/stdc++.h>
using namespace std;
// Array capacities: MAXN bounds the number of vertices; MAXLOG is the number of
// binary-lifting levels stored per vertex (2^MAXLOG must exceed any depth
// difference that lca() has to lift by).
const int MAXN = 100010;
const int MAXLOG = 20;
// Sentinel returned by range-minimum queries over an empty range.
const int INF = 1000000000;

// Input values and scratch variables used by main().
int i, n, m, tot, opt, u, v, c, x, y, timer, Lca, tmp;

// Per-vertex tree data filled in by the two DFS passes.
int dep[MAXN], dfn[MAXN], head[MAXN];
// NOTE: with `using namespace std;` under C++17, an unqualified `size` is
// ambiguous with std::size; refer to this array as ::size.
int size[MAXN];
int anc[MAXN][MAXLOG];
int fa[MAXN], top[MAXN], son[MAXN];
// Adjacency-list edge record; each vertex's list is threaded through `nxt`
// starting from head[vertex].
struct Edge
{
    int to;   // endpoint this edge leads to
    int nxt;  // index of the next edge in the same list; 0 ends the list
};
// Each undirected tree edge is inserted twice (once per direction) by main().
Edge e[MAXN<<1];
struct SegmentTree
{
    // Lazy segment tree over dfn positions. A node covering [l, r] keeps:
    //   Min  - minimum value in the range
    //   sum  - how many positions hold exactly Min
    //   cnt  - how many positions hold 0
    //   taga - pending "assign taga to every position" tag; -1 means none
    //   tagb - pending "add tagb to every position" tag; 0 means none
    // An add landing on a node with a pending assignment is folded into taga,
    // so a node never carries both tags at once.
    struct Node
    {
        int l,r,sum,cnt,Min,taga,tagb;
    } Tree[MAXN<<2];

    // Initialise node `index` for positions [l, r]; every value starts at 0.
    inline void build(int index,int l,int r)
    {
        Node &node = Tree[index];
        node.l = l;
        node.r = r;
        node.Min = 0;
        node.sum = node.cnt = r - l + 1;   // all zeros: every position is minimal and zero
        node.taga = -1;
        node.tagb = 0;
        if (l == r) return;
        int mid = (l + r) >> 1;
        build(index<<1, l, mid);
        build(index<<1|1, mid + 1, r);
    }

    // Overwrite every value under node `index` with `val` and leave an assignment tag.
    inline void applyAssign(int index,int val)
    {
        Node &node = Tree[index];
        int len = node.r - node.l + 1;
        node.Min = val;
        node.sum = len;
        node.cnt = (val == 0) ? len : 0;
        node.taga = val;
        node.tagb = 0;
    }

    // Add `val` to every value under node `index`; the multiplicity of the
    // minimum (sum) is unchanged by a uniform shift.
    inline void applyAdd(int index,int val)
    {
        Node &node = Tree[index];
        node.Min += val;
        node.cnt = (node.Min == 0) ? node.sum : 0;
        if (node.taga != -1) node.taga += val;   // fold into the pending assignment
        else node.tagb += val;
    }

    // Push the pending tags of node `index` down to its two children.
    inline void pushdown(int index)
    {
        Node &node = Tree[index];
        if (node.taga != -1)
        {
            applyAssign(index<<1, node.taga);
            applyAssign(index<<1|1, node.taga);
            node.taga = -1;
        }
        if (node.tagb)
        {
            applyAdd(index<<1, node.tagb);
            applyAdd(index<<1|1, node.tagb);
            node.tagb = 0;
        }
    }

    // Recompute node `index` from its two children.
    inline void update(int index)
    {
        const Node &lc = Tree[index<<1];
        const Node &rc = Tree[index<<1|1];
        Node &node = Tree[index];
        node.cnt = lc.cnt + rc.cnt;
        if (lc.Min == rc.Min)
        {
            node.Min = lc.Min;
            node.sum = lc.sum + rc.sum;
        }
        else if (lc.Min < rc.Min)
        {
            node.Min = lc.Min;
            node.sum = lc.sum;
        }
        else
        {
            node.Min = rc.Min;
            node.sum = rc.sum;
        }
    }

    // Shared range walk behind modify()/add(): apply an assignment (assign == true)
    // or an addition of `val` to positions [l, r], which must lie inside node `index`.
    // An empty range (l > r) is a no-op.
    inline void rangeApply(int index,int l,int r,int val,bool assign)
    {
        if (l > r) return;
        if (Tree[index].l == l && Tree[index].r == r)
        {
            if (assign) applyAssign(index, val);
            else applyAdd(index, val);
            return;
        }
        pushdown(index);
        int mid = (Tree[index].l + Tree[index].r) >> 1;
        if (r <= mid) rangeApply(index<<1, l, r, val, assign);
        else if (l > mid) rangeApply(index<<1|1, l, r, val, assign);
        else
        {
            rangeApply(index<<1, l, mid, val, assign);
            rangeApply(index<<1|1, mid + 1, r, val, assign);
        }
        update(index);
    }

    // Assign `val` to every position in [l, r].
    inline void modify(int index,int l,int r,int val)
    {
        rangeApply(index, l, r, val, true);
    }

    // Add `val` to every position in [l, r].
    inline void add(int index,int l,int r,int val)
    {
        rangeApply(index, l, r, val, false);
    }

    // Minimum over positions [l, r]; INF for an empty range (l > r).
    inline int query_min(int index,int l,int r)
    {
        if (l > r) return INF;
        if (Tree[index].l == l && Tree[index].r == r) return Tree[index].Min;
        pushdown(index);
        int mid = (Tree[index].l + Tree[index].r) >> 1;
        if (r <= mid) return query_min(index<<1, l, r);
        if (l > mid) return query_min(index<<1|1, l, r);
        return min(query_min(index<<1, l, mid), query_min(index<<1|1, mid + 1, r));
    }

    // Count of zero values across the whole tree, minus one. Assumes position 1
    // is the root's slot, which has no parent edge and is never updated by the
    // path operations (they start at dfn[lca] + 1), so it always holds one extra 0.
    inline int query()
    {
        return Tree[1].cnt - 1;
    }
} T;
// Append the directed edge u -> v to u's adjacency list. Lists are singly linked
// through Edge::nxt starting at head[u]; edge indices start at 1 so 0 ends a list.
inline void add(int u,int v)
{
    tot++;
    // Standard brace initialisation; the previous `(Edge){...}` form is a C99
    // compound literal, which ISO C++ does not allow (compiler extension only).
    e[tot] = Edge{v, head[u]};
    head[u] = tot;
}
inline void dfs1(int u)
{
int i,v;
size[u] = 1;
anc[u][0] = fa[u];
for (i = 1; i < MAXLOG; i++)
{
if (dep[u] < (1 << i)) break;
anc[u][i] = anc[anc[u][i-1]][i-1];
}
for (i = head[u]; i; i = e[i].nxt)
{
v = e[i].to;
if (fa[u] != v)
{
dep[v] = dep[u] + 1;
fa[v] = u;
dfs1(v);
size[u] += size[v];
if (size[v] > size[son[u]]) son[u] = v;
}
}
}
inline void dfs2(int u,int tp)
{
int i,v;
dfn[u] = ++timer;
top[u] = tp;
if (son[u]) dfs2(son[u],tp);
for (i = head[u]; i; i = e[i].nxt)
{
v = e[i].to;
if (fa[u] != v && son[u] != v) dfs2(v,v);
}
}
// Assign c to every edge on the path from v up to its ancestor u. An edge is
// stored at its deeper endpoint's dfn, so u's own slot (dfn[u]) is excluded.
inline void solve1(int u,int v,int c)
{
    while (top[v] != top[u])
    {
        int chainHead = top[v];
        T.modify(1, dfn[chainHead], dfn[v], c);
        v = fa[chainHead];   // jump past the chain to the next one up
    }
    T.modify(1, dfn[u] + 1, dfn[v], c);
}
// Add c to every edge on the path from v up to its ancestor u (edges live at
// their deeper endpoint's dfn, so dfn[u] itself is skipped).
inline void solve2(int u,int v,int c)
{
    for (; top[v] != top[u]; v = fa[top[v]])
        T.add(1, dfn[top[v]], dfn[v], c);
    T.add(1, dfn[u] + 1, dfn[v], c);
}
inline int query_min(int u,int v)
{
int tu = top[u],tv = top[v],ans = INF;
while (tu != tv)
{
ans = min(ans,T.query_min(1,dfn[tv],dfn[v]));
v = fa[tv]; tv = top[v];
}
ans = min(ans,T.query_min(1,dfn[u]+1,dfn[v]));
return ans;
}
// Lowest common ancestor of x and y via binary lifting on anc[][].
inline int lca(int x,int y)
{
    if (dep[x] > dep[y]) swap(x, y);   // make y the deeper vertex
    int diff = dep[y] - dep[x];
    for (int k = 0; k < MAXLOG; k++)   // lift y to x's depth
        if ((diff >> k) & 1) y = anc[y][k];
    if (x == y) return x;
    for (int k = MAXLOG - 1; k >= 0; k--)
    {
        if (anc[x][k] == anc[y][k]) continue;   // would overshoot the LCA
        x = anc[x][k];
        y = anc[y][k];
    }
    return fa[x];
}
// Read the next integer from stdin into x. Non-digit characters before the number
// are skipped; every '-' among them toggles the sign. At end of input, x is left 0.
template <typename T> inline void read(T &x)
{
    int f = 1; x = 0;
    // Keep getchar()'s result as int: this keeps EOF distinct from real characters
    // and keeps isdigit()'s argument in its defined domain (unsigned char or EOF).
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) { if (c == '-') f = -f; }
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}
// Print x to stdout in decimal, with a leading '-' for negative values.
template <typename T> inline void write(T x)
{
    if (x < 0)
    {
        putchar('-');
        // Peel off the last digit before negating: -x itself overflows (UB) for the
        // most negative value of T (e.g. INT_MIN), whereas -(x / 10) always fits.
        // Integer division truncates toward zero, so x % 10 lies in [-9, 0].
        T q = x / 10;
        int last = static_cast<int>(-(x % 10));
        if (q != 0) write(-q);
        putchar('0' + last);
        return;
    }
    if (x > 9) write(x/10);
    putchar(x%10+'0');
}
// Print x in decimal followed by a newline (puts("") emits just the '\n').
template <class Int> inline void writeln(Int x)
{
    write(x);
    puts("");
}
// Input: n m, then n-1 tree edges "x y", then m operations "opt u v c".
//   opt == 1 : set every edge on the u-v path to c.
//   otherwise: add c to every edge on the u-v path; a negative c is clamped so
//              that the smallest edge on the path ends at 0 rather than below it.
// After each operation, print the number of edges whose value is 0.
int main() {
    read(n); read(m);
    for (i = 1; i < n; i++)
    {
        read(x); read(y);
        add(x,y);
        add(y,x);
    }
    dfs1(1);
    dfs2(1,1);
    T.build(1,1,timer);
    while (m--)
    {
        read(opt);
        read(u); read(v); read(c);
        Lca = lca(u,v);
        if (opt == 1)
        {
            solve1(Lca,u,c);
            solve1(Lca,v,c);
        } else
        {
            tmp = min(query_min(Lca,u),query_min(Lca,v));
            // tmp is INF (1e9) when u == v, so tmp + c can overflow int for a large
            // positive c; compare in long long to keep the check well-defined.
            if (static_cast<long long>(tmp) + c < 0) c = -tmp;
            solve2(Lca,u,c);
            solve2(Lca,v,c);
        }
        writeln(T.query());
    }
    return 0;
}