鏈接
題解
這題我終究還是被卡常了
但我覺得過沒過不重要,學到方法就行了
這題的做法是,先把所有是的倍數的點拿出來,統計有多少條鏈的長度恰好爲
然後再加上的
然後加上的
然後減去的
…
概括就是
統計鏈的個數用長鏈剖分
注意,假設我現在枚舉的約數是,那麼長鏈剖分是要統計那些所有點都是的倍數的鏈,這條鏈只要含有一個非倍數的點,就不能算進來
所以經過那些非倍數的點的鏈都不能算進來,也就是說這些非倍數的點把整棵樹劃分成若干個不連通的塊,我只需要對每個塊單獨統計就行
所以我其實每次只遍歷了那些權值爲的倍數的點
所以總的複雜度就是每個點權值的約數之和
網上大佬的代碼能通過個測試點,但我的只能通過個…我太菜了
代碼
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 500010
#define maxe 1000010
#define cl(x) memset(x,0,sizeof(x))
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
#define de(x) cerr<<#x<<" = "<<x<<endl
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
int read(int x=0)
{
int c, f(1);
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-f;
for(;isdigit(c);c=getchar())x=x*10+c-0x30;
return f*x;
}
int now, vis[maxn], pool[maxn], *f[maxn], tot, a[maxn], D, n;
ll cnt, ans;
vector<int> lis[maxn];
struct Graph
{
int etot, head[maxn], to[maxe], next[maxe], w[maxe];
void clear(int N)
{
for(int i=1;i<=N;i++)head[i]=0;
etot=0;
}
void adde(int a, int b, int c=0){to[++etot]=b;w[etot]=c;next[etot]=head[a];head[a]=etot;}
#define forp(_,__) for(auto p=__.head[_];p;p=__.next[p])
}G;
struct EasyMath
{
int prime[maxn], phi[maxn], mu[maxn];
bool mark[maxn];
void shai(int N)
{
int i, j;
for(i=2;i<=N;i++)mark[i]=false;
*prime=0;
phi[1]=mu[1]=1;
for(i=2;i<=N;i++)
{
if(!mark[i])prime[++*prime]=i, mu[i]=-1, phi[i]=i-1;
for(j=1;j<=*prime and i*prime[j]<=N;j++)
{
mark[i*prime[j]]=true;
if(i%prime[j]==0)
{
phi[i*prime[j]]=phi[i]*prime[j];
break;
}
mu[i*prime[j]]=-mu[i];
phi[i*prime[j]]=phi[i]*(prime[j]-1);
}
}
}
}em;
struct Longest_Chain_Decomposition
{
int len[maxn], son[maxn], depth[maxn], istop[maxn];
void dfs(int u, int fa)
{
son[u]=0;
len[u]=1;
depth[u]=depth[fa]+1;
istop[u]=false;
forp(u,G)
{
int v(G.to[p]); if(v==fa or a[v]%now)continue;
dfs(v,u);
if(len[v]+1>len[u])len[u]=len[v]+1, son[u]=v;
}
forp(u,G)
{
int v(G.to[p]); if(v==fa)continue;
if(v!=son[u])istop[v]=true;
}
}
void run(int root)
{
tot=0;
depth[0]=0, dfs(root,0);
istop[root]=true;
}
}lcd;
void dfs(int u, int fa)
{
vis[u]=now;
int i;
if(lcd.istop[u])
{
f[u] = pool + tot;
tot += lcd.len[u];
}
f[u][0]=1;
if(lcd.son[u])
{
int v(lcd.son[u]);
f[v] = f[u] + 1;
dfs(v,u);
}
if(D<lcd.len[u])cnt += f[u][D];
forp(u,G)
{
int v(G.to[p]);
if(v==fa or v==lcd.son[u] or a[v]%now)continue;
dfs(v,u);
rep(i,0,lcd.len[v]-1)if(D-i-1<lcd.len[u])cnt += ll(f[v][i]) * f[u][D-i-1];
rep(i,0,lcd.len[v]-1)f[u][i+1]+=f[v][i];
}
}
vector<int> fac[maxn];
int main()
{
int T=read(), i, j, M=3e4, kase;
em.shai(M);
rep(i,1,M)
{
for(j=1;j*j<=i;j++)
if(i%j==0)
{
fac[i].emb(j);
if(i/j!=j)fac[i].emb(i/j);
}
}
rep(kase,1,T)
{
n = read(), D = read();
rep(i,1,n)a[i]=read();
G.clear(n);
rep(i,1,n-1)
{
int u = read(), v = read();
G.adde(u,v), G.adde(v,u);
}
rep(i,1,M)lis[i].clear();
rep(i,1,n)for(auto d:fac[a[i]])lis[d].emb(i);
ans=0;
rep(now,2,M)
{
if(em.mu[now]==0)continue;
cnt=0;
for(auto x:lis[now])
{
if(vis[x]==now)continue;
tot=0;
lcd.run(x);
dfs(x,0);
}
ans += em.mu[now] * cnt;
}
printf("Case #%d: %lld\n",kase,-ans*2);
}
return 0;
}