樹統計(虛樹)

時間限制: 1 Sec  內存限制: 128 MB

題目描述

騙分過樣例,暴力出奇跡。
關於樹的算法有一大堆,樣樣都是毒瘤。
比如說 NOIP2018 提高組的 D2T3,如果會動態 DP 的做法那麼就馬上想到正解,但是 Tweetuzki 不會動態 DP,就只好騙分了。
可惜樹題的碼量也是超級大的。聽說好多學長都會動態 DP,但是考場上調不出來,只好暴力分收場了。瘋狂暗示
Tweetuzki 當時暴力寫掛了,有 4 個點寫成了死循環……於是分數白白少了 16 分。Tweetuzki 一想起這事,不禁夙夜憂嘆,輾轉反側。
現在他又遇到一道毒瘤樹上問題了,他下定決心:這次一定要把暴力分寫滿!
題目是這樣的:
有一棵 n 個點的樹,邊有邊權,每個點有顏色 ci。求所有顏色不同的點對的距離之和。由於答案可能很大,你只需要輸出其對 998,244,353 取模的結果即可。
形式化地講,記 u 號點和 v 號點在樹上的距離爲 dist(u,v),求:

輸入

輸入文件將會遵循以下格式:
n type
c1 c2 ⋯ cn
u1 v1 w1
u2 v2 w2

un−1 vn−1 wn−1
第一行兩個正整數 n,type(2≤n≤2×105,1≤type≤6),其中 n 表示點數,type爲部分分類型,詳見數據範圍,type=0 表示樣例數據。
第二行輸入 n 個正整數 ci(1≤ci≤109),表示每個點的顏色。
接下來n−1 行,每行輸入三個正整數 ui,vi,wi(1≤ui<vi≤n,1≤wi≤109),描述這棵樹。

輸出

輸出一行一個非負整數,表示答案對 998,244,353 取模的結果。

樣例輸入 Copy

4 0
1 2 3 3
1 2 5
2 3 4
3 4 7

樣例輸出 Copy

90

提示

滿足條件的點對有 (1,2),(1,3),(1,4),(2,1),(2,3),(2,4),(3,1),(3,2),(4,1),(4,2),故答案爲 5+9+16+5+4+11+9+4+16+11=90。

Subtask #1:n≤300, type=1。
Subtask #2:n≤2 000, type≤2。
Subtask #3:n≤10 000, type≤3。
Subtask #4:對於第 i (1≤i≤n) 號點,ci=i。type=4。
Subtask #5 :對於第 i(1≤i<n)條邊,ui+1=vi。type=5。
Subtask #6:無特殊性質,type≤6。

 

題目要求不同顏色頂點間的距離和,我們轉化爲所有頂點間的距離和-相同顏色點間的距離和

對於所有頂點間的距離和,我們跑一遍圖,求出每條邊左右的頂點對數即可求出每條邊的貢獻,最終得到所有邊的貢獻

將顏色相同的頂點分別建立一棵虛樹,每一顆虛樹類似上面跑一遍圖即可

最終答案<<1即可

/**/
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include <string>
#include <stack>
#include <queue>

typedef long long LL;
using namespace std;

const long long mod = 998244353;
const int maxn = 200005;

int n, type, tot, cnt, top, len;
int c[maxn], b[maxn];
int head[maxn], sz[maxn], son[maxn], topf[maxn], f[maxn], dep[maxn], dfn[maxn];
LL ans, res, dis[maxn];
int e[maxn], s[maxn], dp[maxn];
bool vis[maxn];

vector<int> v[maxn];
vector<pair<int, LL> > g[maxn];

struct node
{
	int v, w, next;
}a[maxn << 1];

bool cmp(int x, int y){
	return dfn[x] < dfn[y];
}

void dfs(int x, int pre){
	sz[x] = 1;
	dep[x] = dep[pre] + 1;
	f[x] = pre;
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(v == pre) continue;
		dis[v] = (dis[x] + a[i].w) % mod;
		dfs(v, x);
		ans = (ans + 1LL * sz[v] * (n - sz[v]) % mod * a[i].w % mod) % mod;
		sz[x] += sz[v];
		if(sz[son[x]] < sz[v]) son[x] = v;
	}
}

void dfs1(int x, int topfa){
	topf[x] = topfa;
	dfn[x] = ++cnt;
	if(!son[x]) return ;
	dfs1(son[x], topfa);
	for (int i = head[x]; i != -1; i = a[i].next){
		int v = a[i].v;
		if(topf[v]) continue;
		dfs1(v, v);
	}
}

int LCA(int x, int y){
	while(topf[x] != topf[y]){
		if(dep[topf[x]] < dep[topf[y]]) swap(x, y);
		x = f[topf[x]];
	}
	if(dep[x] > dep[y]) swap(x, y);
	return x;
}

void add_edge(int u, int v){
	if(u == n + 1) g[u].emplace_back(make_pair(v, 0));
	else g[u].emplace_back(make_pair(v, (dis[v] - dis[u] + mod) % mod));
}

void insert(int x){
	if(top <= 1){
		s[++top] = x;
		return ;
	}
	int lca = LCA(s[top], x);
	if(lca == s[top]){
		s[++top] = x;
		return ;
	}
	while(top > 1 && dfn[lca] <= dfn[s[top - 1]]){
		add_edge(s[top - 1], s[top]);
		top--;
	}
	if(lca != s[top]) add_edge(lca, s[top]), s[top] = lca;
	s[++top] = x;
}

void dfs2(int u){
	dp[u] = vis[u];
	for (auto x : g[u]){
		int v = x.first;
		LL w = x.second;
		dfs2(v);
		dp[u] += dp[v];
		res = (res + 1LL * dp[v] * (len - dp[v]) % mod * w % mod) % mod;
	}
	g[u].clear();
}

int main()
{
	//freopen("in.txt", "r", stdin);
	//freopen("out.txt", "w", stdout);

	memset(head, -1, sizeof(head));
	scanf("%d %d", &n, &type);
	for (int i = 1; i <= n; i++) scanf("%d", &c[i]), b[i] = c[i];
	sort(b + 1, b + 1 + n);
	int num = unique(b + 1, b + 1 + n) - b - 1;
	for (int i = 1; i <= n; i++) c[i] = lower_bound(b + 1, b + 1 + num, c[i]) - b;
	for (int i = 1; i <= n; i++) v[c[i]].emplace_back(i);
	for (int i = 1, u, v, w; i < n; i++){
		scanf("%d %d %d", &u, &v, &w);
		a[tot] = node{v, w, head[u]}, head[u] = tot++;
		a[tot] = node{u, w, head[v]}, head[v] = tot++;
	}
	dfs(1, 0);
	dfs1(1, 1);
	for (int i = 1; i <= num; i++){
		if(v[i].empty()) continue;
		len = v[i].size();
		for (int j = 0; j < len; j++) e[j + 1] = v[i][j], vis[e[j + 1]] = true;
		sort(e + 1, e + 1 + len, cmp);
		s[top = 1] = n + 1;
		for (int j = 1; j <= len; j++) insert(e[j]);
		while(top > 1) add_edge(s[top - 1], s[top]), top--;
		res = 0;
		dfs2(n + 1);
		for (int j = 1; j <= len; j++) vis[e[j]] = false;
		ans = (ans - res + mod) % mod;
	}
	printf("%lld\n", (ans << 1) % mod);

	return 0;
}
/*
8 3
1 2 3 1 3 3 1 2
1 2 1
2 4 2
2 5 2
5 6 3
5 7 3
1 3 4
3 8 4
*/

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章