hdu6723 wls 的樹(LCA+線段樹合併)

 wls有一棵有根樹,其中的點從1到n標號,其中1是樹根。每次wls可以執行兩種操作中的一個:

(1)選定一個點x,將以x爲根的子樹變成一條按照編號排序的鏈,其中編號最大的作爲新的子樹的根(成爲原來x的父親節點的兒子,如果原來x沒有父親節點則新的子樹的根也沒有父親節點)。

(2)查詢兩個點之間的最短路徑上經過了多少邊。

 

對每一個點都建一個線段樹。

對於操作1

將x和其子樹所有點進行合併。

對於操作2的查詢

如果兩個點都沒有被拉成鏈,則直接求ans=dis(u)+dis(v)-dis(lca(u,v))*2。

如果兩個點在同一個鏈上,則通過線段樹求出在鏈上2點的距離(即U,V間有多少個數)。

如果不再同一個鏈上,則ans=dis(U的鏈頭)+dis(V的鏈頭)-lca(U的鏈頭,V的鏈頭)*2+U到鏈頭的距離(即大於U的點的個數)+V到鏈頭的距離(即大於V的點的個數)

#include<map>
#include<stack>
#include<queue>
#include<cstdio>
#include<algorithm>
#include<vector>
#include <assert.h>
#include<cstring>
#include<cmath>
#include<iostream>
#include<string>
#include<bitset>
using namespace std;
typedef long long ll;
#define mid ((l+r)>>1)
const int N = 200020;
int  rt[N * 40], sum[N * 40], ls[N * 40], rs[N * 40], fa[N][20], depth[N], flag[N];
int fb[N];
vector<int>e[N];
int n;
int tot;
int newNode() {
	tot++;
	ls[tot] = rs[tot] = sum[tot] = 0;
	return tot;
}
int build(int &p, int l, int r, int x) {
	p = newNode();
	if (l == r) {
		sum[p] = 1;
		return p;
	}
	if (x <= mid)
		build(ls[p], l, mid, x);
	else
		build(rs[p], mid + 1, r, x);
	sum[p] = 1;

}
void dfs(int p, int f, int dep) {
	flag[p] = 0;
	build(rt[p], 1, n, p);
	fa[p][0] = f;
	depth[p] = dep;
	for (int i = 0; i < e[p].size(); i++) {
		int v = e[p][i];
		if (v != f) {
			dfs(v, p, dep + 1);
		}
	}
}
int Union(int u, int v, int l, int r) {
	if (u == 0 || v == 0)return u + v;
	int p = newNode();
	if (l == r) {
		sum[p] = sum[u] + sum[v];
		return p;
	}
	ls[p] = Union(ls[u], ls[v], l, mid);
	rs[p] = Union(rs[u], rs[v], mid + 1, r);
	sum[p] = sum[rs[p]] + sum[ls[p]];
	return p;
}
int find(int x) {
	if (x == fb[x])return x;
	return fb[x] = find(fb[x]);
}
void dfs1(int p, int fg) {
	if (flag[p]) {
		fb[p] = fg;
		return;
	}
	fb[p] = fg;
	flag[p] = 1;
	for (int i = 0; i < e[p].size(); i++) {
		int v = e[p][i];
		if (v != fa[p][0]) {
			dfs1(v, fg);
			rt[p] = Union(rt[p], rt[v], 1, n);
		}
	}
}
int getlca(int x, int y) {
	if (depth[x] < depth[y]) {
		swap(x, y);
	}
	for (int i = 17; i >= 0; i--) {
		if ((1 << i) <= depth[x] - depth[y]) {
			x = fa[x][i];
		}
	}
	if (x == y)return x;
	for (int i = 17; i >= 0; i--) {
		if (fa[x][i] != fa[y][i]) {
			x = fa[x][i];
			y = fa[y][i];
		}
	}
	return fa[x][0];
}
int query(int p, int l, int r, int x, int y) {
	//if (p == 0)return 0;
	if (l == x && y == r) {
		return sum[p];
	}
	if (y <= mid) {
		return  query(ls[p], l, mid, x, y);
	}
	else if (x > mid) {
		return query(rs[p], mid + 1, r, x, y);
	}
	else {
		return  query(ls[p], l, mid, x, mid) + query(rs[p], mid + 1, r, mid + 1, y);
	}
}
int main()
{
	int  u, v;
	int t;
	scanf("%d", &t);
	while (t--) {
		tot = 0;
		scanf("%d", &n);
		for (int i = 1; i < n; i++) {
			scanf("%d%d", &u, &v);
			e[u].push_back(v);
			e[v].push_back(u);
		}
		for (int i = 1; i <= n; i++)fb[i] = i;
		memset(fa, 0, sizeof(fa));
		dfs(1, 0, 0);
		for (int i = 1; i <= 17; i++) {
			for (int j = 1; j <= n; j++) {
				fa[j][i] = fa[fa[j][i - 1]][i - 1];
			}
		}
		int q, f;
		scanf("%d", &q);
		for (int i = 0; i < q; i++) {
			scanf("%d", &f);
			if (f == 1) {
				scanf("%d", &u);
				if (!flag[u])dfs1(u, u);
			}
			else {
				scanf("%d%d", &u, &v);
				int x = find(u);
				int y = find(v);
				int ans = 0;
				if (x == y) {
					if (u < v)swap(u, v);
					ans = query(rt[x], 1, n, v, u) - 1;
				}
				else {
					int lca = getlca(x, y);
					ans = depth[x] + depth[y] - depth[lca] * 2;
					ans += sum[rt[x]] - query(rt[x], 1, n, 1, u) + sum[rt[y]] - query(rt[y], 1, n, 1, v);

				}
				printf("%d\n", ans);
			}
		}
		for (int i = 1; i <= n; i++)e[i].clear();
	}
	return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章