題意是求出樹上有多少條路徑滿足路徑上權值連起來組成的數字是 的倍數。
看到樹上路徑計數就應該知道是點分治。。。
設 爲從 到根的數字表示, 是從根到 的數字表示, 爲 的深度。定義根的深度爲 。
對於一個始於 ,終於 的路徑,可以得到如下公式:
整理一下可以得到:
因此可以通過點分治,每次solve計算經過根且滿足條件的路徑。計算方法就是對於每棵子樹,計算其他子樹對它的貢獻。
複雜度是
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5+7;
using ll = long long;
int n;
ll M, inv10[N];
ll a[N], b[N], c[N], d[N], ans, exp10[N];
int f[N], sz[N], del[N], rt, siz;
struct Edge {
int to,w;
};
vector<Edge> adj[N];
map<ll, int> cnt;
void getroot(int u, int p) {
sz[u]=1;f[u]=0;
for(Edge e : adj[u]) {
int v = e.to;
if(v==p||del[v]) continue;
getroot(v, u);
sz[u]+=sz[v];
f[u] = max(f[u], sz[v]);
}
f[u] = max(f[u], siz-sz[u]);
if(f[u]<f[rt]) rt=u;
}
ll extend_gcd(ll a, ll b, ll &x, ll &y) {
// ax+by=1
ll ans, t;
if(b==0) {
x=1; y=0;
return a;
}
ans = extend_gcd(b, a%b, x, y);
t=x;
x=y;
y=t-(a/b)*y;
return ans;
}
void dfs(int u, int p, int w) {
// a[u] = u到根的數字表示
// b[u] = 根到u的數字表示
// d[u] = u的深度
d[u] = d[p] + 1;
b[u] = (b[p]*10+w)%M;
a[u] = (a[p]+w*exp10[d[u]-1])%M;
c[u] = (M-b[u])%M*inv10[d[u]]%M;
for(Edge e : adj[u]) {
int v = e.to;
int w = e.w;
if(v==p||del[v]) continue;
dfs(v, u, w);
}
}
void update_cnt(int u, int p, int val) {
cnt[c[u]]+=val;
for(Edge e : adj[u]) {
int v = e.to;
if(v==p||del[v]) continue;
update_cnt(v, u, val);
}
}
ll cal(int u, int p) {
ll res = cnt[a[u]];
for(Edge e : adj[u]) {
int v = e.to;
if(v==p||del[v]) continue;
res += cal(v, u);
}
return res;
}
void solve(int u) {
// printf("%d\n", u);
dfs(u, 0, 0);
cnt.clear();
update_cnt(u, 0, 1);
del[u] = 1;
for(Edge e : adj[u]) {
int v = e.to;
if(del[v]) continue;
update_cnt(v, 0, -1);
ll res = cal(v, u);
ans += res;
// printf("u:%d, v:%d, res:%I64d\n", u, v, res);
update_cnt(v, 0, 1);
}
ans += cnt[0]-1; //從重心開始的路徑貢獻
for(Edge e : adj[u]) {
int v = e.to;
if(del[v]) continue;
f[0]=siz=sz[v];
rt=0;
getroot(v, 0);
solve(rt);
}
}
int main() {
scanf("%d%I64d", &n, &M);
ll t, i10;
extend_gcd(10, M, i10, t);
i10 = (i10%M+M)%M;
exp10[0] = 1;
inv10[0] = 1;
for(int i=1; i<N; ++i) {
exp10[i]=exp10[i-1]*10%M;
inv10[i]=inv10[i-1]*i10%M;
}
for(int i=1; i<n; ++i) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
++u; ++v;
adj[u].push_back({v, w});
adj[v].push_back({u, w});
}
d[0]=-1; b[0]=0; a[0]=0;
f[0]=siz=n;
rt=0;
getroot(1, 0);
solve(rt);
printf("%I64d\n", ans);
}