wls有一棵有根樹,其中的點從1到n標號,其中1是樹根。每次wls可以執行兩種操作中的一個:
(1)選定一個點x,將以x爲根的子樹變成一條按照編號排序的鏈,其中編號最大的作爲新的子樹的根(成爲原來x的父親節點的兒子,如果原來x沒有父親節點則新的子樹的根也沒有父親節點)。
(2)查詢兩個點之間的最短路徑上經過了多少邊。
對每一個點都建一個線段樹。
對於操作1
將x和其子樹所有點進行合併。
對於操作2的查詢
如果兩個點都沒有被拉成鏈,則直接求ans=dis(u)+dis(v)-dis(lca(u,v))*2。
如果兩個點在同一個鏈上,則通過線段樹求出在鏈上2點的距離(即U,V間有多少個數)。
如果不再同一個鏈上,則ans=dis(U的鏈頭)+dis(V的鏈頭)-lca(U的鏈頭,V的鏈頭)*2+U到鏈頭的距離(即大於U的點的個數)+V到鏈頭的距離(即大於V的點的個數)
#include<map>
#include<stack>
#include<queue>
#include<cstdio>
#include<algorithm>
#include<vector>
#include <assert.h>
#include<cstring>
#include<cmath>
#include<iostream>
#include<string>
#include<bitset>
using namespace std;
typedef long long ll;
#define mid ((l+r)>>1)
const int N = 200020;
int rt[N * 40], sum[N * 40], ls[N * 40], rs[N * 40], fa[N][20], depth[N], flag[N];
int fb[N];
vector<int>e[N];
int n;
int tot;
int newNode() {
tot++;
ls[tot] = rs[tot] = sum[tot] = 0;
return tot;
}
int build(int &p, int l, int r, int x) {
p = newNode();
if (l == r) {
sum[p] = 1;
return p;
}
if (x <= mid)
build(ls[p], l, mid, x);
else
build(rs[p], mid + 1, r, x);
sum[p] = 1;
}
void dfs(int p, int f, int dep) {
flag[p] = 0;
build(rt[p], 1, n, p);
fa[p][0] = f;
depth[p] = dep;
for (int i = 0; i < e[p].size(); i++) {
int v = e[p][i];
if (v != f) {
dfs(v, p, dep + 1);
}
}
}
int Union(int u, int v, int l, int r) {
if (u == 0 || v == 0)return u + v;
int p = newNode();
if (l == r) {
sum[p] = sum[u] + sum[v];
return p;
}
ls[p] = Union(ls[u], ls[v], l, mid);
rs[p] = Union(rs[u], rs[v], mid + 1, r);
sum[p] = sum[rs[p]] + sum[ls[p]];
return p;
}
int find(int x) {
if (x == fb[x])return x;
return fb[x] = find(fb[x]);
}
void dfs1(int p, int fg) {
if (flag[p]) {
fb[p] = fg;
return;
}
fb[p] = fg;
flag[p] = 1;
for (int i = 0; i < e[p].size(); i++) {
int v = e[p][i];
if (v != fa[p][0]) {
dfs1(v, fg);
rt[p] = Union(rt[p], rt[v], 1, n);
}
}
}
int getlca(int x, int y) {
if (depth[x] < depth[y]) {
swap(x, y);
}
for (int i = 17; i >= 0; i--) {
if ((1 << i) <= depth[x] - depth[y]) {
x = fa[x][i];
}
}
if (x == y)return x;
for (int i = 17; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
int query(int p, int l, int r, int x, int y) {
//if (p == 0)return 0;
if (l == x && y == r) {
return sum[p];
}
if (y <= mid) {
return query(ls[p], l, mid, x, y);
}
else if (x > mid) {
return query(rs[p], mid + 1, r, x, y);
}
else {
return query(ls[p], l, mid, x, mid) + query(rs[p], mid + 1, r, mid + 1, y);
}
}
int main()
{
int u, v;
int t;
scanf("%d", &t);
while (t--) {
tot = 0;
scanf("%d", &n);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 1; i <= n; i++)fb[i] = i;
memset(fa, 0, sizeof(fa));
dfs(1, 0, 0);
for (int i = 1; i <= 17; i++) {
for (int j = 1; j <= n; j++) {
fa[j][i] = fa[fa[j][i - 1]][i - 1];
}
}
int q, f;
scanf("%d", &q);
for (int i = 0; i < q; i++) {
scanf("%d", &f);
if (f == 1) {
scanf("%d", &u);
if (!flag[u])dfs1(u, u);
}
else {
scanf("%d%d", &u, &v);
int x = find(u);
int y = find(v);
int ans = 0;
if (x == y) {
if (u < v)swap(u, v);
ans = query(rt[x], 1, n, v, u) - 1;
}
else {
int lca = getlca(x, y);
ans = depth[x] + depth[y] - depth[lca] * 2;
ans += sum[rt[x]] - query(rt[x], 1, n, 1, u) + sum[rt[y]] - query(rt[y], 1, n, 1, v);
}
printf("%d\n", ans);
}
}
for (int i = 1; i <= n; i++)e[i].clear();
}
return 0;
}