這道題非常巧妙!!!
我們進行點分治的時候,算出當前子節點的所有子樹中的節點,到當前節點節點的兒子節點的距離,如下圖意思就是
當前節點的紅色節點,我們要求出紅色節點的兒子節點綠色節點,所有綠色的子樹節點的到當綠色的點權乘積
有如下的情況:
1*5*7 3*6*7
2*5*7 4*6*7
然後我們要想辦法查詢其他鏈上到紅色節點的乘積,比如藍色的所有子樹到紅色節點的乘積,以及這些乘積對應的鏈的尾部節點。
因此我們需要用逆元求,因爲我們並不容易直接求出一條鏈上所有節點的點權乘積爲K的鏈,但是我們可以通過搜索出所有當前節點的乘積,然後查詢逆元長度的鏈條是否存在,更加方便的求出答案。
比較抽象。。。多打幾遍就懂了。。。
#pragma comment(linker,"/STACK:102400000,102400000")
#include<iostream>
#include<stdio.h>
#include<algorithm>
#include<string.h>
#define LL long long
using namespace std;
const int INF = 0x3f3f3f3f;
const int maxx = 2e5+6;
const int MOD = 1000003;
int ver[maxx],head[maxx],Next[maxx],q[maxx];
int sz[maxx],mp[MOD+10],vis[maxx],a[maxx],id[maxx];
int inv[MOD+10];
int tot,mx,size,root,l,r,ansx,ansy,k;
inline int read()
{
int x=0;char ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x;
}
void add(int u,int v){
ver[++tot]=v;Next[tot]=head[u];head[u]=tot;
ver[++tot]=u;Next[tot]=head[v];head[v]=tot;
}
///求重心
void getroot(int u,int fa){
sz[u]=1;
int num=0;
for (int i=head[u];i;i=Next[i]){
int v=ver[i];
if (v==fa||vis[v])continue;
getroot(v,u);
sz[u]+=sz[v];
num=max(num,sz[v]);
}
num=max(num,size-sz[u]);
if (num<mx)mx=num,root=u;
}
///求子樹的鏈的點權積
void getdis(int u,int fa,int val){
q[++r]=val;
id[r]=u;
for (int i=head[u];i;i=Next[i]){
int v=ver[i];
if (v==fa || vis[v])continue;
getdis(v,u,(LL)val*a[v]%MOD);
}
}
///檢查逆元所對應的長度是否存在
void check(int x,int val){
int w=(LL)inv[val]*k%MOD;
int y=mp[w];
if (y==0||x==y)return;
if (x>y)swap(x,y);
if (x<ansx || (x==ansx && y<ansy)){
ansx=x;
ansy=y;
}
return;
}
void solve(int u){
vis[u]=1;
mp[a[u]]=u;
///求出當前節點的子樹對應的點權積
for (int i=head[u];i;i=Next[i]){
int v=ver[i];
if (vis[v])continue;
r=0;
getdis(v,u,a[v]);
for (int j=1;j<=r;j++){
check(id[j],q[j]);
}
///把所有子樹鏈的乘積再乘上當前節點的權值,
///這樣保存使得另外一顆子樹的一條鏈能夠輕鬆找到另外一條不和自己在同一個子樹內且點權乘積爲K的長度
for (int j=1;j<=r;j++){
q[j]=(LL)q[j]*a[u]%MOD;
int now=mp[q[j]];
if (now==0 || now>id[j]){
mp[q[j]]=id[j];
}
}
}
mp[a[u]]=0;
///要繼續點分治,父親節點的信息以及沒有用了
for (int i=head[u];i;i=Next[i]){
int v=ver[i];
if(vis[v])continue;
r=0;
l=1;
getdis(v,u,(LL)a[u]*a[v]%MOD);
for(int j=1;j<=r;j++){
mp[q[j]]=0;
}
}
for (int i=head[u];i;i=Next[i]){
int v=ver[i];
if (vis[v])continue;
size=sz[v];
mx=INF;
getroot(v,0);
solve(root);
}
}
int main(){
inv[1]=1;
for (int i=2;i<MOD;i++){
inv[i]=(LL)(MOD-(MOD/i))*inv[MOD%i]%MOD;
}
int n;
while(~scanf("%d%d",&n,&k)){
for(int i=1;i<=n;i++){
a[i]=read();
}
tot=0;
memset(mp,0,sizeof(mp));
int u,v;
for (int i=1;i<=n;i++){
vis[i]=0;
head[i]=0;
}
tot=0;
for (int i=1;i<n;i++){
u=read();
v=read();
add(u,v);
}
ansx=INF;
ansy=INF;
mx=INF;
size=n;
getroot(1,0);
solve(root);
if (ansx==INF){
printf("No solution\n");
}else {
printf("%d %d\n",ansx,ansy);
}
}
return 0;
}