時間限制: 1 Sec 內存限制: 128 MB
題目描述
騙分過樣例,暴力出奇跡。
關於樹的算法有一大堆,樣樣都是毒瘤。
比如說 NOIP2018 提高組的 D2T3,如果會動態 DP 的做法那麼就馬上想到正解,但是 Tweetuzki 不會動態 DP,就只好騙分了。
可惜樹題的碼量也是超級大的。聽說好多學長都會動態 DP,但是考場上調不出來,只好暴力分收場了。瘋狂暗示
Tweetuzki 當時暴力寫掛了,有 4 個點寫成了死循環……於是分數白白少了 16 分。Tweetuzki 一想起這事,不禁夙夜憂嘆,輾轉反側。
現在他又遇到一道毒瘤樹上問題了,他下定決心:這次一定要把暴力分寫滿!
題目是這樣的:
有一棵 n 個點的樹,邊有邊權,每個點有顏色 ci。求所有顏色不同的點對的距離之和。由於答案可能很大,你只需要輸出其對 998,244,353 取模的結果即可。
形式化地講,記 u 號點和 v 號點在樹上的距離爲 dist(u,v),求:
輸入
輸入文件將會遵循以下格式:
n type
c1 c2 ⋯ cn
u1 v1 w1
u2 v2 w2
⋮
un−1 vn−1 wn−1
第一行兩個正整數 n,type(2≤n≤2×105,1≤type≤6),其中 n 表示點數,type爲部分分類型,詳見數據範圍,type=0 表示樣例數據。
第二行輸入 n 個正整數 ci(1≤ci≤109),表示每個點的顏色。
接下來n−1 行,每行輸入三個正整數 ui,vi,wi(1≤ui<vi≤n,1≤wi≤109),描述這棵樹。
輸出
輸出一行一個非負整數,表示答案對 998,244,353 取模的結果。
樣例輸入 Copy
4 0 1 2 3 3 1 2 5 2 3 4 3 4 7
樣例輸出 Copy
90
提示
滿足條件的點對有 (1,2),(1,3),(1,4),(2,1),(2,3),(2,4),(3,1),(3,2),(4,1),(4,2),故答案爲 5+9+16+5+4+11+9+4+16+11=90。
Subtask #1:n≤300, type=1。
Subtask #2:n≤2 000, type≤2。
Subtask #3:n≤10 000, type≤3。
Subtask #4:對於第 i (1≤i≤n) 號點,ci=i。type=4。
Subtask #5 :對於第 i(1≤i<n)條邊,ui+1=vi。type=5。
Subtask #6:無特殊性質,type≤6。
題目要求不同顏色頂點間的距離和,我們轉化爲所有頂點間的距離和-相同顏色點間的距離和
對於所有頂點間的距離和,我們跑一遍圖,求出每條邊左右的頂點對數即可求出每條邊的貢獻,最終得到所有邊的貢獻
將顏色相同的頂點分別建立一棵虛樹,每一顆虛樹類似上面跑一遍圖即可
最終答案<<1即可
/**/
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include <string>
#include <stack>
#include <queue>
typedef long long LL;
using namespace std;
const long long mod = 998244353;
const int maxn = 200005;
int n, type, tot, cnt, top, len;
int c[maxn], b[maxn];
int head[maxn], sz[maxn], son[maxn], topf[maxn], f[maxn], dep[maxn], dfn[maxn];
LL ans, res, dis[maxn];
int e[maxn], s[maxn], dp[maxn];
bool vis[maxn];
vector<int> v[maxn];
vector<pair<int, LL> > g[maxn];
struct node
{
int v, w, next;
}a[maxn << 1];
bool cmp(int x, int y){
return dfn[x] < dfn[y];
}
void dfs(int x, int pre){
sz[x] = 1;
dep[x] = dep[pre] + 1;
f[x] = pre;
for (int i = head[x]; i != -1; i = a[i].next){
int v = a[i].v;
if(v == pre) continue;
dis[v] = (dis[x] + a[i].w) % mod;
dfs(v, x);
ans = (ans + 1LL * sz[v] * (n - sz[v]) % mod * a[i].w % mod) % mod;
sz[x] += sz[v];
if(sz[son[x]] < sz[v]) son[x] = v;
}
}
void dfs1(int x, int topfa){
topf[x] = topfa;
dfn[x] = ++cnt;
if(!son[x]) return ;
dfs1(son[x], topfa);
for (int i = head[x]; i != -1; i = a[i].next){
int v = a[i].v;
if(topf[v]) continue;
dfs1(v, v);
}
}
int LCA(int x, int y){
while(topf[x] != topf[y]){
if(dep[topf[x]] < dep[topf[y]]) swap(x, y);
x = f[topf[x]];
}
if(dep[x] > dep[y]) swap(x, y);
return x;
}
void add_edge(int u, int v){
if(u == n + 1) g[u].emplace_back(make_pair(v, 0));
else g[u].emplace_back(make_pair(v, (dis[v] - dis[u] + mod) % mod));
}
void insert(int x){
if(top <= 1){
s[++top] = x;
return ;
}
int lca = LCA(s[top], x);
if(lca == s[top]){
s[++top] = x;
return ;
}
while(top > 1 && dfn[lca] <= dfn[s[top - 1]]){
add_edge(s[top - 1], s[top]);
top--;
}
if(lca != s[top]) add_edge(lca, s[top]), s[top] = lca;
s[++top] = x;
}
void dfs2(int u){
dp[u] = vis[u];
for (auto x : g[u]){
int v = x.first;
LL w = x.second;
dfs2(v);
dp[u] += dp[v];
res = (res + 1LL * dp[v] * (len - dp[v]) % mod * w % mod) % mod;
}
g[u].clear();
}
int main()
{
//freopen("in.txt", "r", stdin);
//freopen("out.txt", "w", stdout);
memset(head, -1, sizeof(head));
scanf("%d %d", &n, &type);
for (int i = 1; i <= n; i++) scanf("%d", &c[i]), b[i] = c[i];
sort(b + 1, b + 1 + n);
int num = unique(b + 1, b + 1 + n) - b - 1;
for (int i = 1; i <= n; i++) c[i] = lower_bound(b + 1, b + 1 + num, c[i]) - b;
for (int i = 1; i <= n; i++) v[c[i]].emplace_back(i);
for (int i = 1, u, v, w; i < n; i++){
scanf("%d %d %d", &u, &v, &w);
a[tot] = node{v, w, head[u]}, head[u] = tot++;
a[tot] = node{u, w, head[v]}, head[v] = tot++;
}
dfs(1, 0);
dfs1(1, 1);
for (int i = 1; i <= num; i++){
if(v[i].empty()) continue;
len = v[i].size();
for (int j = 0; j < len; j++) e[j + 1] = v[i][j], vis[e[j + 1]] = true;
sort(e + 1, e + 1 + len, cmp);
s[top = 1] = n + 1;
for (int j = 1; j <= len; j++) insert(e[j]);
while(top > 1) add_edge(s[top - 1], s[top]), top--;
res = 0;
dfs2(n + 1);
for (int j = 1; j <= len; j++) vis[e[j]] = false;
ans = (ans - res + mod) % mod;
}
printf("%lld\n", (ans << 1) % mod);
return 0;
}
/*
8 3
1 2 3 1 3 3 1 2
1 2 1
2 4 2
2 5 2
5 6 3
5 7 3
1 3 4
3 8 4
*/