傳送門
解:點分治。輸入是一棵樹(n 個點、n-1 條邊);把樹上所有路徑按「路徑長度 %3」的餘數 0、1、2 分成三類,分別求每一類的路徑長度之和(對 1e9+7 取模)。
#include<bits/stdc++.h>
// Shorthand for the `inline` keyword used on every function in this file.
#define il inline
// Shorthand for the containers' push_back member.
#define pb push_back
// Fills an array with the byte value v; sizeof is taken on the argument itself,
// so this is only meaningful for true arrays, not pointers.
#define ms(_data,v) memset(_data,v,sizeof(_data))
// Container size converted to int, to avoid signed/unsigned comparisons in loops.
#define SZ(a) int((a).size())
using namespace std;
typedef long long ll;
// Large sentinel (0x3f3f3f3f); it is assigned to the int `mx` before each getroot pass.
const ll inf=0x3f3f3f3f;
// Capacity of every per-vertex array in this file (1e4 plus a little slack).
const int maxn=1e4+5;
// Modulus applied to the accumulated answers.
const int mod=1e9+7;
/**
 * Reads one decimal integer from stdin into x.
 *
 * Leading non-digit characters are skipped. A '-' negates the value only when
 * it is the character immediately before the first digit. The character that
 * terminates the digit run is consumed as well.
 *
 * If end-of-file is reached before any digit is seen, x is left at 0 and the
 * function returns instead of spinning forever.
 *
 * @param x destination; always overwritten.
 */
template <typename _Tp> inline void read(_Tp&x) {
    x=0;
    bool negative=false;
    // Keep the result of getchar() in an int: EOF must stay distinguishable
    // from every real character, and isdigit() is only defined for
    // unsigned-char values and EOF.
    int ch=getchar();
    while(ch!=EOF && !isdigit(ch)){
        negative=(ch=='-');
        ch=getchar();
    }
    while(ch!=EOF && isdigit(ch)){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    if(negative) x=-x;
}
// Adds y to x in place, subtracting `mod` at most once when the sum reaches it.
// The result is therefore only guaranteed to be in [0, mod) when both operands
// already were. Returns the new x converted to int.
il int Add(ll &x,ll y) {
    const ll sum=x+y;
    if(sum>=mod) x=sum-mod;
    else x=sum;
    return x;
}
// Multiplies x by y in place; the product is reduced with `% mod` only when it
// is at least `mod`, otherwise it is stored unchanged. Returns the new x
// converted to int.
il int Mul(ll &x,ll y) {
    const ll product=x*y;
    if(product>=mod) x=product%mod;
    else x=product;
    return x;
}
// ---- Shared scratch state (zero-initialised, reused across test cases) ----
int n;           // vertex count read by main
int k;           // declared but not referenced in the visible code
int sz[maxn];    // subtree sizes written by getroot
int root;        // vertex selected by the most recent getroot pass
int mx;          // best (smallest) score seen so far by getroot
int all;         // vertex total that getroot measures the "upward" part against
int cnt;         // number of entries currently stored in q (1-based)

ll dis[maxn];    // per-vertex distance filled in by getdis
ll q[maxn];      // distances collected by getdis, q[1..cnt]
ll ans0;         // accumulated sum for path lengths with residue 0 mod 3
ll ans1;         // accumulated sum for path lengths with residue 1 mod 3
ll ans2;         // accumulated sum for path lengths with residue 2 mod 3

bool vis[maxn];  // vertices already removed by dfs
// One adjacency-list entry: the neighbouring vertex and the weight of the
// edge leading to it. Kept as a plain aggregate so `node{to, w}` works.
struct node {
    int to;  // index of the neighbouring vertex
    int w;   // weight of the connecting edge
};
// Adjacency list: mp[v] holds one node entry per edge incident to vertex v.
// main inserts every edge in both directions and clears the lists after each test case.
vector<node> mp[maxn];
/**
 * Centroid search over the not-yet-removed component containing x.
 *
 * Recomputes sz[] for every vertex reached (rooted at the initial call's x)
 * and, for each vertex, scores it by the size of the largest piece that would
 * remain if that vertex were removed. The vertex with the smallest score seen
 * so far is stored in `root`, its score in `mx`.
 *
 * Preconditions set by the caller: `mx` holds a value larger than any score
 * (the file uses `inf`) and `all` holds the number of vertices in the component.
 *
 * @param x  vertex being visited
 * @param fa vertex we arrived from (0 for the starting vertex)
 */
il void getroot(int x,int fa){
    sz[x]=1;
    int heaviest=0;  // largest piece left behind if x is removed
    for(const node &e:mp[x]){
        if(vis[e.to] || e.to==fa) continue;
        getroot(e.to,x);
        sz[x]+=sz[e.to];
        heaviest=max(heaviest,sz[e.to]);
    }
    // The piece "above" x is everything outside x's subtree, i.e. all-sz[x].
    // (The previous code used all minus the largest child, which is not the
    // size of any real piece and could select a non-centroid vertex.)
    heaviest=max(heaviest,all-sz[x]);
    if(heaviest<mx){
        mx=heaviest;
        root=x;
    }
}
/**
 * Depth-first walk that records distances inside the current component.
 *
 * Appends dis[x] to q (advancing cnt), then for every neighbour that is
 * neither the vertex we came from nor already removed, sets its dis[] to
 * dis[x] plus the edge weight and recurses.
 *
 * The caller must have set dis[x] for the starting vertex and reset cnt.
 *
 * @param x  vertex being visited
 * @param fa vertex we arrived from (0 for the starting vertex)
 */
il void getdis(int x,int fa){
    cnt+=1;
    q[cnt]=dis[x];
    for(const node &edge:mp[x]){
        const int child=edge.to;
        if(child!=fa && !vis[child]){
            dis[child]=dis[x]+edge.w;
            getdis(child,x);
        }
    }
}
/**
 * Applies a signed contribution to a modular accumulator.
 *
 * @param rs   accumulator, updated in place; on return it is in [0, mod)
 * @param x    contribution; it may be any magnitude (callers in this file pass
 *             values up to about 2*mod), so it is reduced before use
 * @param type 1 adds x, any other value subtracts x
 */
il void run(ll &rs,ll x,int type){
    x%=mod;
    if(type!=1) x=-x;
    // C++ '%' keeps the sign of the dividend, so add mod once more before the
    // final reduction; without this a subtraction of x > rs + mod left rs negative.
    rs=((rs+x)%mod+mod)%mod;
}
/**
 * Accumulates, for every ordered pair (i, j) of vertices reachable from x in
 * the current component, the value q[i] + q[j] into ans0 / ans1 / ans2
 * according to (q[i] + q[j]) mod 3, where q[] are the distances collected by
 * getdis starting from dis[x] = v.
 *
 * @param x    start vertex of the distance walk
 * @param v    distance assigned to x itself
 * @param type forwarded to run(): 1 adds the contribution, otherwise subtracts
 */
il void calc(int x,int v,int type){
    cnt=0;
    dis[x]=v;
    getdis(x,0);

    // Per-residue totals of the collected distances: how many, and their sum.
    ll many[3]={0,0,0};
    ll total[3]={0,0,0};
    for(int i=1;i<=cnt;++i){
        const int r=static_cast<int>(q[i]%3);
        if(r<0) continue;  // a negative remainder belongs to no residue class
        Add(total[r],q[i]);
        ++many[r];
    }

    ll *const bucket[3]={&ans0,&ans1,&ans2};
    for(int i=1;i<=cnt;++i){
        const int r=static_cast<int>(q[i]%3);
        if(r<0) continue;
        for(int target=0;target<3;++target){
            // Partners of q[i] that make the pair's total land in `target`.
            const int partner=(target-r+3)%3;
            run(*bucket[target],many[partner]*q[i]%mod+total[partner],type);
        }
    }
}
/**
 * Divide-and-conquer step rooted at x.
 *
 * Adds the contribution of all vertex pairs in x's component measured through
 * x (calc with type 1), marks x as removed, and for each remaining neighbour
 * subtracts the pairs that lie inside that neighbour's side (calc with type
 * -1, starting at the edge weight), then recurses into the vertex getroot
 * selects for that side.
 *
 * @param x vertex chosen as the split point of the current component
 */
il void dfs(int x){
    calc(x,0,1);
    vis[x]=1;
    for(const node &e:mp[x]){
        if(vis[e.to]) continue;
        calc(e.to,e.w,-1);
        // Sizing pass: sz[] still describes an older rooting, so walk the
        // component once purely to obtain its real vertex count in sz[e.to].
        // (Previously `all` was set to sz[x], the stale subtree size of the
        // vertex just removed, which is not the size of this component.)
        all=0,mx=inf;
        getroot(e.to,0);
        // Selection pass with the correct total.
        all=sz[e.to],mx=inf;
        getroot(e.to,0);
        dfs(root);
    }
}
int main(){
while(scanf("%d",&n)!=EOF){
all=n,ans0=0,ans1=0,ans2=0;
int x,y,z;
for(int i=1;i<=n-1;++i){
read(x),read(y),read(z);
x++,y++;
mp[x].pb(node{y,z});
mp[y].pb(node{x,z});
}
mx=inf;
getroot(1,0);
dfs(root);
printf("%d %d %d\n",ans0%mod,ans1%mod,ans2%mod);
for(int i=0;i<n+5;++i){
mp[i].clear();vis[i]=0;
}
}
return 0;
}