SCOI2018 Numazu 的蜜柑

傳送門
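判別式 A²−4B 在模 p 下不是二次剩餘(即方程無解)時,只有兩個值都為 0 的點對滿足條件,因此只需統計值為 0 的「祖先—後代」點對數。
有解時要注意兩個根算出的統計值是否相同(重根,或該點的值為 0):相同時只能計入一次,否則會重複統計。
dfs 時用雜湊表維護的是當前點到根的鏈上各點乘以兩個根之後的值的出現次數。
注意使用快速乘:模數較大時,兩數直接相乘會超出 64 位元整數的範圍。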

#include<bits/stdc++.h>
#include<tr1/unordered_map>
// Shorthand macros used throughout the solution.
//   pll   : pair of two long longs (the two square roots returned by SQRT)
//   mp    : std::make_pair
//   ll    : long long
//   se/fi : pair members second / first
//   cs    : const
#define pll std::pair<ll,ll>
#define mp std::make_pair
#define ll long long
// `register` is no longer a storage-class specifier as of C++17, so code
// written as `int re i` fails to compile (or warns) on modern compilers.
// It was only ever an optimisation hint, so expanding `re` to nothing keeps
// every existing use valid without changing behaviour.
#define re
#define se second
#define cs const
#define fi first

// Capacity of every per-node / per-edge array (same value as 1e5+10).
cs int N=100010;
// Number of tree nodes, read in main().
int n;
// Forward-star adjacency lists filled by adde(): Head[u] is the newest edge
// of u, Next[e] the previous edge in the same list, V[e] the edge's target.
int Head[N],Next[N],V[N];
// Number of edges stored so far; edge ids start at 1, 0 terminates a list.
int cnt=0;
// Modulus and the two coefficients A, B read in main().
ll mod,A,B;
// Value attached to each node (1-indexed).
ll a[N];
// inv2: inverse of 2 mod `mod`; sum: the answer accumulated by the DFS;
// k1, k2: the two multipliers computed in main() and applied in dfs1().
ll inv2,sum,k1,k2;
// Modular addition: returns x+y, with one subtraction of `mod` if the sum
// reaches it. Callers pass operands that are already below `mod`.
inline ll add(ll x,ll y){
	ll s=x+y;
	if(s>=mod) s-=mod;
	return s;
}
// Modular subtraction: returns x-y, adding `mod` back once when the
// difference is negative.
inline ll dec(ll x,ll y){
	ll d=x-y;
	if(d<0) d+=mod;
	return d;
}
// "Quick multiplication": returns x*y mod `mod` for products that do not fit
// in 64 bits.
//
// The quotient floor(x*y/mod) is estimated in long double; the true remainder
// is then x*y - q*mod, which is small even though both terms overflow.
// The subtraction is performed in unsigned long long so that the wrap-around
// is well defined (overflowing a signed long long is undefined behaviour).
// The estimate can be off by one either way, so the remainder is normalised
// into [0, mod) at the end.
inline ll mul(ll x,ll y){
	cs ll q=(ll)((long double)x/mod*y);
	cs ll r=(ll)((unsigned long long)x*(unsigned long long)y-(unsigned long long)q*(unsigned long long)mod);
	return (r%mod+mod)%mod;
}
// In-place modular multiplication: x becomes mul(x, y).
inline void Mul(ll &x,ll y){
	cs ll product=mul(x,y);
	x=product;
}

// Binary exponentiation: returns ret * a^b under the modulus used by mul().
// `ret` defaults to 1, so a plain call yields a^b.
inline ll quickpow(ll a,ll b,ll ret=1){
	while(b){
		if(b&1) Mul(ret,a);   // this bit of the exponent is set
		Mul(a,a);             // square the base for the next bit
		b>>=1;
	}
	return ret;
}
// Modular inverse computed as x^(mod-2); by Fermat's little theorem this is
// the inverse when `mod` is prime (the code does not check primality).
inline ll inv(ll x){
	cs ll exponent=mod-2;
	return quickpow(x,exponent);
}

// Buffered reader for non-negative decimal integers on stdin.
namespace IO{
	// Size of the read buffer; stdin is consumed in chunks of this many bytes.
	const int Rlen=1<<22|1;
	// buf holds the current chunk; [p1, p2) is the part not yet consumed.
	char buf[Rlen],*p1,*p2;
	// Returns the next input byte as a value in [0, 255], or EOF once stdin is
	// exhausted. The result is an int (like getchar) so that EOF cannot be
	// confused with a data byte and is always a valid argument for isdigit().
	inline int gc(){
		if(p1==p2){
			p2=(p1=buf)+fread(buf,1,Rlen,stdin);
			if(p1==p2) return EOF;
		}
		return (unsigned char)*p1++;
	}
	// Skips everything up to the next digit, then parses a run of digits as a
	// non-negative integer of type T. Returns 0 if the input ends before any
	// digit is found (instead of spinning forever on EOF). Signs are not
	// interpreted: a '-' is skipped like any other non-digit.
	template<typename T>
	inline T get(){
		int ch=gc();T x=0;
		while(ch!=EOF&&!isdigit(ch)) ch=gc();
		while(ch!=EOF&&isdigit(ch)) x=x*10+(ch-'0'),ch=gc();
		return x;
	}
	// Convenience wrappers for the two integer widths the solution reads.
	inline int gi(){return get<int>();}
	inline long long gl(){return get<long long>();}
}
using IO::gi;
using IO::gl;
// Appends the directed edge u -> v to u's adjacency list: the new edge takes
// the next free id and is linked in front of u's previous newest edge.
inline void adde(int u,int v){
	++cnt;
	Next[cnt]=Head[u];
	V[cnt]=v;
	Head[u]=cnt;
}

// Modular square roots via Cipolla's algorithm.
namespace Cipolla{
	// NQ: the random trial value; w = NQ^2 - X, used as the square of the
	// imaginary unit in plx arithmetic.
	ll NQ,w;
	// Storage for the pair of roots; main() assigns SQRT's result to it.
	pll ans;
	// Element x + y*i of the extension field in which i*i == w.
	struct plx{
		ll x,y;
		plx(ll X=0,ll Y=0):x(X),y(Y){}
		// (x1 + y1 i)(x2 + y2 i) = (x1 x2 + y1 y2 w) + (x1 y2 + y1 x2) i
		friend inline plx operator*(cs plx &lhs,cs plx &rhs){
			cs ll real=add(mul(lhs.x,rhs.x),mul(mul(lhs.y,rhs.y),w));
			cs ll imag=add(mul(lhs.x,rhs.y),mul(lhs.y,rhs.x));
			return plx(real,imag);
		}
		// Binary exponentiation in the extension field.
		friend inline plx operator^(plx base,ll e){
			plx acc(1,0);
			while(e){
				if(e&1) acc=acc*base;
				base=base*base;
				e>>=1;
			}
			return acc;
		}
	};
	// Returns both square roots of X modulo `mod`:
	//   (0, 0)      when X == 0,
	//   (-1, -1)    when Euler's criterion reports X as a non-residue,
	//   (r, mod-r)  otherwise.
	inline pll SQRT(ll X){
		if(X==0) return mp(0ll,0ll);
		cs ll half=(mod-1)>>1;
		if(quickpow(X,half)==mod-1) return mp(-1,-1);
		// Draw random NQ until NQ^2 - X is a quadratic non-residue.
		do{
			NQ=rand();
			w=dec(mul(NQ,NQ),X);
		}while(quickpow(w,half)!=mod-1);
		// (NQ + i)^((mod+1)/2) has zero imaginary part; its real part is a root.
		cs ll root=(plx(NQ,1ll)^((mod+1)>>1)).x;
		return mp(root,mod-root);
	}
}
using Cipolla::SQRT;
using Cipolla::ans;

int cnt0=0;
inline void dfs0(int u){
	if(a[u]==0) sum+=cnt0,++cnt0;
	for(int re i=Head[u],v=V[i];i;v=V[i=Next[i]]) dfs0(v);
	if(a[u]==0) --cnt0;
}

std::tr1::unordered_map<ll,int> S;
inline void dfs1(int u){
	sum+=S[a[u]];ll val1=mul(a[u],k1),val2=mul(a[u],k2);
	(val1==val2)?(++S[val1]):(++S[val1],++S[val2]);
	for(int re i=Head[u],v=V[i];i;v=V[i=Next[i]]) dfs1(v);
	(val1==val2)?(--S[val1]):(--S[val1],--S[val2]);
}

int main(){
	//freopen("3372.in","r",stdin);
	srand(time(0));
	n=gi(),mod=gl(),A=gl()%mod,B=gl()%mod;
	ans=SQRT(dec(mul(A,A),mul(4ll,B))),inv2=inv(2);
	k1=mul(inv2,add(mod-A,ans.fi)),k2=mul(inv2,add(mod-A,ans.se));
	for(int re i=1;i<=n;++i) a[i]=gl();
	for(int re i=2,x;i<=n;++i) x=gi(),adde(x,i);
	if(ans.fi==-1) dfs0(1),printf("%lld\n",sum);
	else  dfs1(1),printf("%lld\n",sum);
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章