樹統計 (樹鏈剖分+線段樹)

時間限制: 1 Sec  內存限制: 128 MB

題目描述

然而,這一切宛如一度揉過的複寫紙,無不同原來有着少許然而卻是無可挽回的差異。—— 村上春樹

關於樹的算法有一大堆,樣樣都是毒瘤。

比如說 2019 CSP-S 的樹論題,如果擅長樹形數據結構馬上想到正解,但是 3edc2wsx1qaz 並不擅長,就只好騙分了。

3edc2wsx1qaz 當時數組開小了,慘遭 RE,3edc2wsx1qaz 一想起這事,不禁夙夜憂嘆,輾轉反側。

現在他又遇到一道毒瘤的樹上問題了,他下定決心:這次一定要寫出正解!

題目是這樣的:

有一顆有n個點的樹,每條邊有一個權值ai。樹的根節點爲1號節點。定義一對點對(u,v)的距離dist(u,v)爲在u到v的簡單路徑上的所有邊的邊權的異或。

你需要進行q次操作,操作分爲兩種:
1.將x點與它父親所連的邊的邊權異或w。
2.詢問以節點y爲根的子樹中所有點對的距離之和,答案對998244353取模
也就是說,對於每次 2 操作,設以節點y爲根的子樹的節點集合爲subtrss(y), 你需要求出以下式子的值:


 

輸入

爲了方便你獲取部分分,我們會告訴你測試點編號。

第一行輸入三個正整數n,q,r(2≤n≤10^5,2≤q≤10^5,1≤r≤50),表示樹的節點數,操作數,該測試點編號。

接下來n-1行每行三個正整數u,v,w,表示有一條連接u,v,權值爲w的邊。(1≤u≤n,1≤v≤n,0≤w<2^10)

接下來q行,每行開頭輸入一個數opt(opt=1 或opt=2 ),表示操作類型。

若opt=1,則再輸入兩個數x,w(1<x≤n,0≤w<2^10),表示將x號點與它父親所連的邊的邊權異或w。

若opt=2 ,則再輸入一個數y,表示一次詢問,你需要輸出以節點y爲根的子樹中所有點對的距離之和。

輸出

輸出若干行,對於每次 2 操作,輸出一個正整數,表示答案。

樣例輸入 Copy

8 8 0
2 1 0
3 1 0
4 3 0
5 2 1
6 5 1
7 5 0
8 1 0
1 4 0
2 7
1 3 0
2 5
1 5 1
1 4 0
1 5 0
2 1

樣例輸出 Copy

0
4
14

提示

樣例解釋:
由於這組數據爲樣例,所以r=0。
保證測試數據中1≤r≤50


 

對於一顆子樹內的任意兩點(x,y)之間距離(如題所述的距離)爲dis(1,x)^dis(1,y)

那麼我們知道統計子樹按位拆開後對應二進制爲1的個數即可。

對於更新操作,修改邊權爲w,枚舉二進制位,當且僅當w對應的二進制爲1時必發生當前點及子樹對應位0和1個數的互換,用線段樹維護即可。

 

/**/
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include <string>
#include <stack>
#include <queue>

typedef long long LL;
using namespace std;

const long long mod = 998244353;
const int maxn = 1e5 + 5;

int n, q, r, tot, cnt;
int head[maxn], dfn[maxn], sz[maxn], son[maxn], top[maxn], f[maxn], id[maxn], w[maxn];
int tr[11][maxn << 2], lzy[11][maxn << 2];

struct node
{
	int v, w, next;
}a[maxn << 1];

void dfs(int x, int pre){
	f[x] = pre;
	sz[x] = 1;
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(v == pre) continue;
		w[v] = w[x] ^ a[i].w;
		dfs(v, x);
		sz[x] += sz[v];
		if(sz[son[x]] < sz[v]) son[x] = v;
	}
}

void dfs1(int x, int topf){
	top[x] = topf;
	dfn[x] = ++cnt;
	id[cnt] = x;
	if(son[x]) dfs1(son[x], topf);
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(v == f[x] || v == son[x]) continue;
		dfs1(v, v);
	}
}

void up(int rt){
	for (int i = 0; i < 10; i++){
		tr[i][rt] = tr[i][rt << 1] + tr[i][rt << 1 | 1];
	}
}

void up(int rt, int i){
	tr[i][rt] = tr[i][rt << 1] + tr[i][rt << 1 | 1];
}

void down(int rt, int l, int r){
	int mid = (l + r) >> 1;
	for (int i = 0; i < 10; i++){
		if(lzy[i][rt]){
			lzy[i][rt << 1] ^= 1;
			lzy[i][rt << 1 | 1] ^= 1;
			tr[i][rt << 1] = (mid - l + 1) - tr[i][rt << 1];
			tr[i][rt << 1 | 1] = (r - mid) - tr[i][rt << 1 | 1];
			lzy[i][rt] = 0;
		}
	}
}

void build(int rt, int l, int r){
	if(l == r){
		for (int i = 0; i < 10; i++){
			if(1 << i & w[id[l]]) tr[i][rt] = 1;
			else tr[i][rt] = 0;
		}
		return ;
	}
	int mid = (l + r) >> 1;
	build(rt << 1, l, mid);
	build(rt << 1 | 1, mid + 1, r);
	up(rt);
}

void update(int rt, int l, int r, int L, int R, int i){
	if(L <= l && r <= R){
		tr[i][rt] = r - l + 1 - tr[i][rt];
		lzy[i][rt] ^= 1;
		return ;
	}
	down(rt, l, r);
	int mid = (l + r) >> 1;
	if(mid >= L) update(rt << 1, l, mid, L, R, i);
	if(mid < R) update(rt << 1 | 1, mid + 1, r, L, R, i);
	up(rt, i);
}

int query(int rt, int l, int r, int L, int R, int i){
	if(L <= l && r <= R) return tr[i][rt];
	down(rt, l, r);
	int mid = (l + r) >> 1, ans = 0;
	if(mid >= L) ans += query(rt << 1, l, mid, L, R, i);
	if(mid < R) ans += query(rt << 1 | 1, mid + 1, r, L, R, i);
	return ans;
}

void modify(int x, int W){
	for (int i = 0; i < 10; i++){
		if(W >> i & 1) update(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, i);
	}
	w[x] ^= W;
}

LL sum(int x){
	LL ans = 0;
	for (int i = 0; i < 10; i++){
		int num = query(1, 1, n, dfn[x], dfn[x] + sz[x] - 1, i);
		ans = (ans + 1LL * num * (sz[x] - num) % mod * (1 << i) % mod);
	}
	return ans;
}

int main()
{
	//freopen("in.txt", "r", stdin);
	//freopen("out.txt", "w", stdout);

	memset(head, -1, sizeof(head));
	scanf("%d %d %d", &n, &q, &r);
	for (int i = 1, u, v, w; i < n; i++){
		scanf("%d %d %d", &u, &v, &w);
		a[tot] = node{v, w, head[u]}, head[u] = tot++;
		a[tot] = node{u, w, head[v]}, head[v] = tot++;
	}
	dfs(1, 0);
	dfs1(1, 1);
	build(1, 1, n);
	for (int i = 1, op, x, y, w; i <= q; i++){
		scanf("%d", &op);
		if(op == 1){
			scanf("%d %d", &x, &w);
			modify(x, w);
		}else{
			scanf("%d", &y);
			printf("%lld\n", (sum(y) << 1) % mod);
		}
	}

	return 0;
}
/**/

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章