HDU - 4812 D Tree (點分治 + 逆元預處理 )

D Tree

 

There is a skyscraping tree standing on the playground of Nanjing University of Science and Technology. On each branch of the tree is an integer (The tree can be treated as a connected graph with N vertices, while each branch can be treated as a vertex). Today the students under the tree are considering a problem: Can we find such a chain on the tree so that the product of all integers on the chain (mod $10^6 + 3$) equals K?
Can you help them in solving this problem?

Input

There are several test cases, please process till EOF.
Each test case starts with a line containing two integers N ($1 \le N \le 10^5$) and K ($0 \le K < 10^6 + 3$). The following line contains N numbers $v_i$ ($1 \le v_i < 10^6 + 3$), where $v_i$ indicates the integer on vertex i. Then follow N - 1 lines. Each line contains two integers x and y, representing an undirected edge between vertex x and vertex y.

Output

For each test case, print a single line containing two integers a and b (where a < b), representing the two endpoints of the chain. If multiple solutions exist, please print the lexicographically smallest one. In case no solution exists, print “No solution” (without quotes) instead.
For more information, please refer to the Sample Output below.

Sample Input

5 60
2 5 2 3 3
1 2
1 3
2 4
2 5
5 2
2 5 2 3 3
1 2
1 3
2 4
2 5

Sample Output

3 4
No solution


        
  

Hint

1. “please print the lexicographically smallest one.”是指: 先按照第一個數字的大小進行比較,若第一個數字大小相同,則按照第二個數字大小進行比較,依次類推。

2. 若出現棧溢出,推薦使用C++語言提交,並通過以下方式擴棧:
#pragma comment(linker,"/STACK:102400000,102400000")

 

代碼:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;             // wide type for modular products (a type, not a macro)
const int N=1e5+5;                // max vertex count plus slack
const int M=1e6+5;                // >= mod: size of tables indexed by a residue
const int mod=1e6+3;              // prime modulus from the statement
const int inf=1e9+5;              // "no answer yet" marker; also f[0] guard for getroot()
// n/k: current test case; root/sum: centroid candidate and component size
// used by getroot(); tot/first: adjacency-list edge counter and list heads.
int n,k,root,sum,tot,first[N];
// siz/f: DFS subtree size and largest remaining piece (getroot());
// vis: vertices already used as centroids; d: path products;
// dep: products gathered by getdep(), with dep[0] holding the count.
int siz[N],f[N],vis[N],dep[N],d[N];
// inv: modular inverses; mp: path product -> smallest endpoint (0 = none);
// a: vertex values; id: vertex for each dep[] slot, so it only needs N
// entries like dep[] (it used to be sized M, wasting ~4 MB).
int inv[M],mp[M],a[N],id[N],ansu,ansv;
// One adjacency-list entry: v is the vertex the edge points to and nex is the
// index of the next entry for the same source vertex (-1 ends the list).
struct node
{
    int v;
    int nex;
};
node e[N<<1];                     // main() stores every undirected edge twice
void add(int u,int v)
{
    e[tot].v=v;
    e[tot].nex=first[u];
    first[u]=tot++;
}
// DFS over the component containing u (fa is the DFS parent; vertices with
// vis set are not entered). For every visited vertex w it stores siz[w], the
// subtree size, and f[w], the largest piece left if w were removed (using
// sum as the component size). The first vertex whose f beats f[root] is kept
// in root, so callers reset root to 0 (f[0] is inf) beforehand.
// Recursion depth can reach the component size.
void getroot(int u,int fa)
{
    siz[u]=1;
    int heaviest=0;
    for(int i=first[u];i!=-1;i=e[i].nex)
    {
        const int to=e[i].v;
        if(to!=fa&&!vis[to])
        {
            getroot(to,u);
            siz[u]+=siz[to];
            if(siz[to]>heaviest) heaviest=siz[to];
        }
    }
    const int above=sum-siz[u];
    f[u]=heaviest>above?heaviest:above;
    if(f[u]<f[root]) root=u;
}
void getdep(int u,int fa)
{
    dep[++dep[0]]=d[u];
    id[dep[0]]=u;
    for(int i=first[u];~i;i=e[i].nex)
    {
        int v=e[i].v;
        if(v==fa||vis[v]) continue;
        d[v]=(ll)d[u]*a[v]%mod;
        getdep(v,u);
    }
}
void update(int x,int u)
{
    int y=(ll)inv[x]*k%mod;
    int v=mp[y];
    if(v==0) return;
    if(u>v) swap(u,v);
    if(u<ansu||(u==ansu&&v<ansv))
        ansu=u,ansv=v;
}
// Handles the component whose centroid is u. It checks every chain through
// u, then decomposes each remaining sub-component recursively.
//
// Expects on entry: sum == size of this component, and siz[] filled by the
// getroot() call that picked u. That DFS was rooted at the call's start
// vertex, not at u.
//
// While u is processed, mp[p] holds the smallest vertex w such that the path
// u..w (u included) has product p, or 0 if there is none. Each subtree is
// matched against mp before it is inserted, so the two endpoints always lie
// in different subtrees of u (or one of them is u itself).
void solve(int u)
{
    const int total=sum;          // component size; sum is overwritten by recursion
    vis[u]=1;
    mp[a[u]]=u;                   // the one-vertex path {u}
    for(int i=first[u];~i;i=e[i].nex)
    {
        int v=e[i].v;
        if(vis[v]) continue;
        // Products of paths v..w inside this subtree, excluding u.
        d[v]=a[v],dep[0]=0;
        getdep(v,u);
        // Pair them with paths already stored (earlier subtrees, or u alone)...
        for(int j=1;j<=dep[0];j++) update(dep[j],id[j]);
        // ...then store them with u's value folded in, keeping the smallest id.
        for(int j=1;j<=dep[0];j++)
        {
            int x=(ll)dep[j]*a[u]%mod;
            if(!mp[x]||mp[x]>id[j]) mp[x]=id[j];
        }
    }
    // Zero exactly the mp entries written above, so the next centroid starts
    // from an all-zero table.
    mp[a[u]]=0;
    for(int i=first[u];~i;i=e[i].nex)
    {
        int v=e[i].v;
        if(vis[v]) continue;
        d[v]=(ll)a[v]*a[u]%mod,dep[0]=0;
        getdep(v,u);
        for(int j=1;j<=dep[0];j++) mp[dep[j]]=0;
    }
    // Recurse into each sub-component. siz[v] equals that sub-component's size
    // only when v was below u in the getroot() DFS (then siz[v] < siz[u]).
    // If v was u's DFS parent, siz[v] covers the whole component; the part
    // left after removing u's subtree is total-siz[u].
    for(int i=first[u];~i;i=e[i].nex)
    {
        int v=e[i].v;
        if(vis[v]) continue;
        root=0;
        sum=siz[v]<siz[u]?siz[v]:total-siz[u];
        getroot(v,0);
        solve(root);
    }
}
// Resets per-test-case state: clears the product table, unmarks all
// centroids, empties every adjacency list and records "no answer yet".
void init()
{
    memset(mp,0,sizeof(mp));
    memset(vis,0,sizeof(vis));
    memset(first,-1,sizeof(first));
    ansu=inf;
    ansv=inf;
    tot=0;
    root=0;
}
// Builds the inverse table once, then answers test cases until input ends.
int main()
{
    // inv[i] = -(mod/i) * inv[mod % i] (mod mod), valid because mod is prime.
    // inv[0] is only a placeholder: 0 has no inverse.
    inv[0]=1;
    inv[1]=1;
    for(int i=2;i<mod;++i)
        inv[i]=(int)((ll)(mod-mod/i)*inv[mod%i]%mod);
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        init();
        for(int i=1;i<=n;++i) scanf("%d",&a[i]);
        for(int i=1;i<n;++i)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            add(x,y);
            add(y,x);
        }
        // Start the decomposition at the centroid of the whole tree.
        f[0]=inf;
        root=0;
        sum=n;
        getroot(1,0);
        solve(root);
        if(ansu!=inf) printf("%d %d\n",ansu,ansv);
        else printf("No solution\n");
    }
    return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章