P2633 Count on a tree
Description
- 給定一棵N個節點的樹,每個點有一個權值,對於M個詢問(u,v,k),你需要回答u xor lastans和v這兩個節點間第K小的點權。其中lastans是上一個詢問的答案,初始爲0,即第一個詢問的u是明文。
Input
第一行兩個整數N,M。
第二行有N個整數,其中第i個整數表示點i的權值。
後面N-1行每行兩個整數(x,y),表示點x到點y有一條邊。
最後M行每行兩個整數(u,v,k),表示一組詢問。
Output
- M行,表示每個詢問的答案。
Sample Input
8 5 105 2 9 3 8 5 7 7 1 2 1 3 1 4 3 5 3 6 3 7 4 8 2 5 1 0 5 2 10 5 3 11 5 4 110 8 2
Sample Output
2 8 9 105 7
Data Size
- N,M<=100000
題解:
- 主席樹。
- 跑到樹上的主席樹,挺好玩的。
- 首先想想模板主席樹是怎樣實現的:每個節點根據前一個節點建立。然後利用前綴和思想拿第r棵樹 - 第l-1棵樹得到[l,r]區間的信息,操作在這棵新樹上操作即可。
- 那麼到了樹上呢?
- 可以對於每一個節點,在它父親基礎上建樹。這樣每一個節點所保存的信息就是它自身到根節點這條鏈上的信息。
- 然後我們就可以解決某點到根路徑上的第k大啦!
- 等等,不是要解決x-y路徑上的第k大嗎?
- s[u]+s[v]−s[lca(u,v)]−s[fa[lca(u,v)]]。這不就表示成了x-y路徑的信息了嘛:D
#include <iostream>
#include <cstdio>
#include <algorithm>
#define N 200005
#define find(x) (lower_bound(b + 1, b + 1 + cnt, x) - b)
using namespace std;
struct T {int l, r, sum;} t[N << 5];
struct E {int next, to;} e[N * 2];
int n, m, num, cnt, dex, last;
int h[N], a[N], b[N], fat[N], dep[N];
int son[N], top[N], size[N], rt[N];
int read()
{
int x = 0, f = 1; char c = getchar();
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
return x *= f;
}
void add(int u, int v)
{
e[++num].next = h[u];
e[num].to = v;
h[u] = num;
}
void dfs1(int x, int fath, int depth)
{
size[x] = 1, fat[x] = fath, dep[x] = depth;
int maxSon = 0;
for(int i = h[x]; i != 0; i = e[i].next)
if(e[i].to != fath)
{
dfs1(e[i].to, x, depth + 1);
size[x] += size[e[i].to];
if(size[e[i].to] > maxSon)
{
maxSon = size[e[i].to];
son[x] = e[i].to;
}
}
}
void dfs2(int x, int head)
{
top[x] = head;
if(!son[x]) return;
dfs2(son[x], head);
for(int i = h[x]; i != 0; i = e[i].next)
if(e[i].to != fat[x] && e[i].to != son[x])
dfs2(e[i].to, e[i].to);
}
int lca(int x, int y)
{
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fat[top[x]];
}
if(dep[x] < dep[y]) return x;
return y;
}
int build(int l, int r)
{
int p = ++dex, mid = l + r >> 1;
if(l == r) return p;
t[p].l = build(l, mid), t[p].r = build(mid + 1, r);
return p;
}
int upd(int las, int l, int r, int val)
{
int p = ++dex, mid = l + r >> 1;
t[p].l = t[las].l, t[p].r = t[las].r, t[p].sum = t[las].sum + 1;
if(l == r) return p;
if(val <= mid) t[p].l = upd(t[las].l, l, mid, val);
else t[p].r = upd(t[las].r, mid + 1, r, val);
return p;
}
void dfs(int x)
{
rt[x] = upd(rt[fat[x]], 1, cnt, find(a[x]));
for(int i = h[x]; i != 0; i = e[i].next)
if(e[i].to != fat[x]) dfs(e[i].to);
}
int ask(int s1, int s2, int fa, int pa, int l, int r, int rank)
{
int size = t[t[s2].l].sum + t[t[s1].l].sum - t[t[fa].l].sum - t[t[pa].l].sum;
int mid = l + r >> 1;
if(l == r) return l;
if(rank <= size) return ask(t[s1].l, t[s2].l, t[fa].l, t[pa].l, l, mid, rank);
else return ask(t[s1].r, t[s2].r, t[fa].r, t[pa].r, mid + 1, r, rank - size);
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
a[i] = read(), b[++cnt] = a[i];
sort(b + 1, b + 1 + cnt);
cnt = unique(b + 1, b + 1 + cnt) - b - 1;
for(int i = 1; i < n; i++)
{
int u = read(), v = read();
add(u, v), add(v, u);
}
rt[0] = build(1, cnt);
dfs1(1, 0, 1), dfs2(1, 1), dfs(1);
for(int i = 1; i <= m; i++)
{
int u = read() ^ last, v = read(), rank = read(), head = lca(u, v);
int res = ask(rt[u], rt[v], rt[head], rt[fat[head]], 1, cnt, rank);
printf("%d\n", last = b[res]);
}
return 0;
}