題目鏈接:BZOJ2152
題意大致是找樹中多少點對之間的距離%3爲0
題解:可以樹形DP,也可以點分治,兩個代碼都粘過來咯
(上面是樹形DP,下面是點分治)
//樹形DP
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
struct edge{
int to,ne,val;
}e[40005];
int head[20005],tot=0;
int n,ans=0,dp[20005][3];
int gcd(int a,int b)
{
return b==0 ? a : gcd(b,a%b);
}
void dfs(int now,int fa)
{
dp[now][0]++;
for(int i=head[now];i;i=e[i].ne)
{
int v=e[i].to;
if (v==fa) continue;
dfs(v,now);
ans+=dp[now][0]*dp[v][(3-e[i].val)%3];
ans+=dp[now][1]*dp[v][(3+2-e[i].val)%3];
ans+=dp[now][2]*dp[v][(3+1-e[i].val)%3];
for(int j=0;j<=2;j++)
dp[now][(j+e[i].val)%3]+=dp[v][j];
}
}
void push(int x,int y,int val)
{
e[++tot].to=y; e[tot].val=val; e[tot].ne=head[x]; head[x]=tot;
e[++tot].to=x; e[tot].val=val; e[tot].ne=head[y]; head[y]=tot;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
push(x,y,z%3);
}
dfs(1,0);
ans=ans*2+n;
int t=gcd(ans,n*n);
printf("%d/%d",ans/t,n*n/t);
return 0;
}
//點分治(代碼有參考hzwer大神)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int n,tot=0,head[20005],rt,siz[20005],sum,mx[20005],d[20005],t[5],ans=0;
bool vis[20005];
struct edge{
int to,ne,val;
}e[40005];
void push(int x,int y,int val)
{
e[++tot].to=y; e[tot].val=val; e[tot].ne=head[x]; head[x]=tot;
e[++tot].to=x; e[tot].val=val; e[tot].ne=head[y]; head[y]=tot;
}
int gcd(int a,int b)
{
return b==0?a:gcd(b,a%b);
}
void getroot(int now,int fa)
{
siz[now]=1,mx[now]=0;
for (int i=head[now];i;i=e[i].ne)
{
int v=e[i].to;
if (!vis[v]&&v!=fa)
{
getroot(v,now);
siz[now]+=siz[v];
mx[now]=max(mx[now],siz[v]);
}
}
if (mx[now]<sum-siz[now]) mx[now]=sum-siz[now];
if (mx[now]<mx[rt]) rt=now;
}
void dfs(int now,int fa)
{
t[d[now]]++;
for (int i=head[now];i;i=e[i].ne)
{
int v=e[i].to;
if (!vis[v]&&v!=fa)
{
d[v]=(d[now]+e[i].val)%3;
dfs(v,now);
}
}
}
int find(int now,int num)
{
t[0]=t[1]=t[2]=0;
d[now]=num;
dfs(now,0);
return t[0]*t[0]+t[1]*t[2]*2;
}
void solve(int now)
{
ans+=find(now,0);
vis[now]=true;
for (int i=head[now];i;i=e[i].ne)
{
int v=e[i].to;
if (!vis[v])
{
ans-=find(v,e[i].val);
rt=0; sum=siz[v];
getroot(v,now);
solve(rt); //寫的太少,這裏rt寫成v,WA了好久……
}
}
}
int main()
{
n=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read(),z=read();
push(x,y,z%3);
}
mx[0]=sum=n;
getroot(1,0);
solve(rt);
int t=gcd(ans,n*n);
printf("%d/%d",ans/t,n*n/t);
return 0;
}