傳送門
無解時（判別式 A²−4B 不是二次剩餘），只有兩點值都為 0 的祖先–後代點對滿足條件，因此只需統計這種點對的數量。
有解時要注意 k1·a[x] 與 k2·a[x] 是否相同（例如 k1 = k2 或 a[x] = 0），相同時只能計數一次。
dfs 時用雜湊表維護當前點到根的鏈上，每個祖先 x 的 k1·a[x] 與 k2·a[x] 的出現次數（回溯時撤銷）。
注意模數可能很大，兩數相乘會溢出 64 位整數，需要使用快速乘（long double 估商技巧）。
#include<bits/stdc++.h>
#include<tr1/unordered_map>
// Shorthand macros used throughout this solution (competitive-programming style).
// Pair of long longs; SQRT returns its two modular square roots in one of these.
#define pll std::pair<ll,ll>
// std::make_pair shorthand.
#define mp std::make_pair
// Integer type used for every modular value (mod, A, B, a[], k1, k2, ...).
#define ll long long
// NOTE(review): expands to the `register` storage class, which C++17 removed;
// under -std=c++17 GCC warns and Clang rejects it.
#define re register
// Pair accessors: `fi` -> first, `se` -> second.
#define se second
// `cs` -> const.
#define cs const
#define fi first
// Array capacity; n is presumably at most 1e5 (TODO: confirm against the problem limits).
cs int N=1e5+10;
// Number of tree nodes.
int n;
// Forward-star child lists: Head[u] is the first edge out of u, Next[e] the
// following edge, V[e] its target node. cnt is the number of edges stored so far.
int Head[N],Next[N],V[N];
int cnt=0;
// mod: the modulus (assumed prime, since inv() relies on Fermat's little theorem).
// A, B: coefficients of k^2 + A*k + B, reduced modulo mod.
ll mod,A,B;
// a[i]: value attached to node i.
ll a[N];
// inv2: inverse of 2; sum: the pair count printed as the answer;
// k1, k2: the two roots of k^2 + A*k + B (when they exist).
ll inv2,sum,k1,k2;
// Modular arithmetic modulo the global `mod`; operands are expected to lie in [0, mod).

// x + y (mod `mod`) for already-reduced operands.
inline ll add(ll x,ll y){return x+y>=mod?x+y-mod:x+y;}
// x - y (mod `mod`) for already-reduced operands.
inline ll dec(ll x,ll y){return x-y<0?x-y+mod:x-y;}
// x * y (mod `mod`) without 128-bit integers.
// q estimates floor(x*y/mod) in long double and may be off by one. The remainder
// x*y - q*mod is then formed exactly in wrap-around (unsigned) 64-bit arithmetic.
// The exact product x*y overflows signed long long as soon as mod exceeds ~3e9, and
// signed overflow is undefined behaviour; unsigned overflow is well-defined.
// The small remainder is finally normalised into [0, mod).
inline ll mul(ll x,ll y){
	typedef unsigned long long ull;
	ll q=(ll)((long double)x/mod*y);
	ll r=(ll)((ull)x*(ull)y-(ull)q*(ull)mod);
	r%=mod;
	return r<0?r+mod:r;
}
// x = x * y (mod `mod`).
inline void Mul(ll &x,ll y){x=mul(x,y);}
// ret * a^b (mod `mod`) by binary exponentiation; b must be non-negative.
inline ll quickpow(ll a,ll b,ll ret=1){for(;b;b>>=1,Mul(a,a))if(b&1) Mul(ret,a);return ret;}
// Modular inverse via Fermat's little theorem: correct only when `mod` is prime
// and x is not a multiple of `mod`.
inline ll inv(ll x){return quickpow(x,mod-2);}
namespace IO{
// Size of the stdin read buffer (4 MiB + 1 bytes).
const int Rlen=1<<22|1;
// buf holds the current chunk of stdin; [p1, p2) is the part not yet consumed.
char buf[Rlen],*p1,*p2;
// Returns the next byte of stdin as an unsigned char value, or EOF once input is
// exhausted. The return type is int (not char) so that EOF stays distinguishable
// from a 0xFF byte and the value is always a valid argument for std::isdigit.
inline int gc(){
	if(p1==p2){
		p1=buf;
		p2=buf+fread(buf,1,Rlen,stdin);
		if(p1==p2) return EOF;
	}
	return static_cast<unsigned char>(*p1++);
}
// Reads the next non-negative decimal integer, skipping any non-digit bytes before it.
// At end of input this returns 0 instead of looping forever.
// A leading '-' is skipped like any other non-digit, so negative numbers are not supported.
template<typename T>
inline T get(){
	int ch=gc();
	T x=0;
	while(ch!=EOF&&!std::isdigit(ch)) ch=gc();
	while(ch!=EOF&&std::isdigit(ch)){
		x=x*10+(ch-'0');
		ch=gc();
	}
	return x;
}
// Reads the next integer as int.
inline int gi(){return get<int>();}
// Reads the next integer as long long.
inline long long gl(){return get<long long>();}
}
using IO::gi;
using IO::gl;
// Adds a directed edge u -> v by prepending it to u's forward-star list.
inline void adde(int u,int v){
	++cnt;
	V[cnt]=v;
	Next[cnt]=Head[u];
	Head[u]=cnt;
}
namespace Cipolla{
// Cipolla's algorithm for square roots modulo the (odd prime) global `mod`.
// NQ: the random base a picked by SQRT.
// w: a^2 - X, chosen to be a quadratic non-residue; it defines the extension
// field F_p[s] with s^2 = w, in which plx arithmetic is done.
ll NQ,w;
// Result slot; main() stores SQRT's result here via `using Cipolla::ans`.
pll ans;
// Element x + y*s of F_p[s], where s^2 = w.
struct plx{
	ll x,y;
	plx(ll X=0,ll Y=0):x(X),y(Y){}
	// (a.x + a.y*s)(b.x + b.y*s) = (a.x*b.x + a.y*b.y*w) + (a.x*b.y + a.y*b.x)*s
	friend inline plx operator*(cs plx &a,cs plx &b){
		ll real_part=add(mul(a.x,b.x),mul(mul(a.y,b.y),w));
		ll imag_part=add(mul(a.x,b.y),mul(a.y,b.x));
		return plx(real_part,imag_part);
	}
	// base^e in F_p[s] by square-and-multiply (e >= 0).
	friend inline plx operator^(plx base,ll e){
		plx acc(1,0);
		while(e){
			if(e&1) acc=acc*base;
			e>>=1;
			base=base*base;
		}
		return acc;
	}
};
// Returns the two square roots of X modulo `mod` as (r, mod - r).
// Returns (0, 0) when X == 0, and (-1, -1) when X is a quadratic non-residue.
inline pll SQRT(ll X){
	if(X==0)return mp(0ll,0ll);
	// Euler's criterion: X^((p-1)/2) == -1 (mod p) means X has no square root.
	if(quickpow(X,(mod-1)>>1)==mod-1)return mp(-1,-1);
	// Draw random bases until w = NQ^2 - X is a non-residue.
	do{
		NQ=rand();
		w=dec(mul(NQ,NQ),X);
	}while(quickpow(w,(mod-1)>>1)!=mod-1);
	// (NQ + s)^((p+1)/2) lies in F_p and squares to X.
	ll root=(plx(NQ,1ll)^((mod+1)>>1)).x;
	return mp(root,mod-root);
}
}
using Cipolla::SQRT;
using Cipolla::ans;
// Number of zero-valued nodes on the path from the root down to the current node's parent.
int cnt0=0;
// No-root case: counts ancestor/descendant pairs whose values are both zero.
// For every zero-valued node u, it adds to the global `sum` the number of
// zero-valued ancestors of u, then recurses into u's children.
inline void dfs0(int u){
	bool isZero=(a[u]==0);
	if(isZero){
		sum+=cnt0;
		cnt0++;
	}
	for(int e=Head[u];e;e=Next[e]) dfs0(V[e]);
	if(isZero) cnt0--;
}
std::tr1::unordered_map<ll,int> S;
inline void dfs1(int u){
sum+=S[a[u]];ll val1=mul(a[u],k1),val2=mul(a[u],k2);
(val1==val2)?(++S[val1]):(++S[val1],++S[val2]);
for(int re i=Head[u],v=V[i];i;v=V[i=Next[i]]) dfs1(v);
(val1==val2)?(--S[val1]):(--S[val1],--S[val2]);
}
int main(){
//freopen("3372.in","r",stdin);
srand(time(0));
n=gi(),mod=gl(),A=gl()%mod,B=gl()%mod;
ans=SQRT(dec(mul(A,A),mul(4ll,B))),inv2=inv(2);
k1=mul(inv2,add(mod-A,ans.fi)),k2=mul(inv2,add(mod-A,ans.se));
for(int re i=1;i<=n;++i) a[i]=gl();
for(int re i=2,x;i<=n;++i) x=gi(),adde(x,i);
if(ans.fi==-1) dfs0(1),printf("%lld\n",sum);
else dfs1(1),printf("%lld\n",sum);
}