題意:給一棵n個點的數以及n個點的點權,一條路徑(x,y)合法當且僅當從x到y的路徑上所有點權的gcd大於1(一條路徑也可以是一個點),問所有合法路徑經過的點的數量最大的經過了多少點。
題解:就是求最長的合法路徑再+1。這種在樹上詢問路徑的可以優先考慮點分治。每次分治到一個重心,把每條鏈的鏈上gcd分解質因數,去和之前這些質因數對應的最大鏈長加起來再+1,然後這些質因數對應的最大鏈長都更新一下。注意每次先把一棵子樹中的鏈記下來,先詢問再更新(避免同一棵子樹內的兩條鏈更新答案)。
還是要注意先加入重心本身這條鏈(一個點),老是忘這操作......
最終還要考慮一下一個點這種情況,因爲有可能不存在兩個不同的點使得路徑上點權gcd大於1,那麼如果這時某個點自身的點權大於1,答案就是1而非0(樣例1和樣例3對比一下)。
交上去居然1A,之前可是做好了調半個下午的思想準備的......但是這帶了剪枝(一條鏈的gcd已經爲1就不再往下求dis了)都還4000多ms也是跑得巨慢誒...
P.S.好像還可以用樹形dp做?
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<utility>
using namespace std;
const int N=2e5+4;
const int INF=0x3f3f3f3f;
typedef long long ll;
typedef pair<int,int> pp;
int n,a[N],b[N];
int head[N],etot;
struct Edge {
int v,nxt;
}e[N<<1];
int dis[N],siz[N],mx[N];
bool vis[N];
int ans;
int root,s,sum;
pp sta[N];
int tot;
int used[N],cnt;
int val[N];
inline void adde(int v,int u) {
e[++etot].nxt=head[u],e[etot].v=v,head[u]=etot;
}
inline void smax(int &a,int b) {
a=a>b?a:b;
}
inline int gcd(int a,int b) {
return !b?a:gcd(b,a%b);
}
inline void getroot(int p,int fa) {
mx[p]=-INF,siz[p]=1;
for (int i=head[p];~i;i=e[i].nxt) {
int v=e[i].v;
if (vis[v]||v==fa) continue;
getroot(v,p);
siz[p]+=siz[v];
smax(mx[p],siz[v]);
}
smax(mx[p],sum-siz[p]);
if (mx[p]<s) s=mx[p],root=p;
}
inline void getdis(int p,int fa) {
b[p]=gcd(a[p],b[fa]);
if (b[p]==1) return ;
sta[++tot]=make_pair(b[p],dis[p]);
for (int i=head[p];~i;i=e[i].nxt) {
int v=e[i].v;
if (vis[v]||v==fa) continue;
dis[v]=dis[p]+1;
getdis(v,p);
}
}
inline void query(int x,int len) {
for (int i=2;i*i<=x;++i)
if (x%i==0) {
while (x%i==0) x/=i;
if (~val[i]) smax(ans,len+val[i]+1);
}
if (x^1&&~val[x]) smax(ans,len+val[x]+1);
}
inline void add(int x,int len) {
for (int i=2;i*i<=x;++i)
if (x%i==0) {
while (x%i==0) x/=i;
if (val[i]==-1) used[++cnt]=i;
smax(val[i],len);
}
if (x^1) {
if (val[x]==-1) used[++cnt]=x;
smax(val[x],len);
}
}
inline void calc(int p) {
dis[p]=0;
b[p]=a[p];
cnt=0;
add(b[p],0);
for (int i=head[p];~i;i=e[i].nxt) {
int v=e[i].v;
if (vis[v]) continue;
dis[v]=1;
tot=0;
getdis(v,p);
for (int j=1;j<=tot;++j) query(sta[j].first,sta[j].second);
for (int j=1;j<=tot;++j) add(sta[j].first,sta[j].second);
}
for (register int i=1;i<=cnt;++i) val[used[i]]=-1;
}
inline void work(int p) {
vis[p]=true;
if (a[p]^1) calc(p);
for (int i=head[p];~i;i=e[i].nxt) {
int v=e[i].v;
if (vis[v]) continue;
s=INF,sum=siz[v];
getroot(v,0);
work(root);
}
}
inline int read() {
int x=0,f=1;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
return x*f;
}
int main() {
memset(head,-1,sizeof(head));
memset(vis,false,sizeof(vis));
memset(val,-1,sizeof(val));
n=read();
for (register int i=1;i<=n;++i) a[i]=read();
for (register int i=1;i<n;++i) {
int u=read(),v=read();
adde(u,v);
adde(v,u);
}
s=INF,sum=n;
getroot(1,0);
work(root);
for (register int i=1;i<=n;++i)
if (a[i]>1) smax(ans,1);
printf("%d\n",ans);
return 0;
}