6712. 【2020.06.09省選模擬】題3 sum

題目


正解

推式子題。
比賽時推了半天的生成函數最終推回了一個遞推式?

容斥一下,答案爲i=0N(1)iCNiCSiTM\sum_{i=0}^N(-1)^iC_N^iC_{S-iT}^M
然後就是推式子:
=i=0N(1)iCNi[xM](1+x)SiT=[xM]i=0N(1)iCNi(1+x)(ni)T(1+x)SnT=[xM]((1+x)(ni)T1)N(1+x)SnT=[xMN]((1+x)(ni)T1x)N(1+x)SnT=\sum_{i=0}^N(-1)^iC_N^i[x^M](1+x)^{S-iT}\\ =[x^M]\sum_{i=0}^N(-1)^iC_N^i(1+x)^{(n-i)T}(1+x)^{S-nT} \\ =[x^M]((1+x)^{(n-i)T}-1)^N(1+x)^{S-nT} \\ =[x^{M-N}](\frac{(1+x)^{(n-i)T}-1}{x})^N(1+x)^{S-nT}
裏面的那項和右邊的那項二項式展開,可以快速計算。
然後左邊那項先lnln,乘上指數之後再expexp,即可以得到乘方。


代碼

經過隔壁大佬指點,從此NTT少兩個模法。

using namespace std;
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 524288
#define ll long long
#define mo 998244353
#define mo2 998244353ll*998244353ll
ll qpow(ll x,ll y=mo-2){
	ll r=1;
	for (;y;y>>=1,x=x*x%mo)
		if (y&1)
			r=r*x%mo;
	return r;
}
ll inv[N+10];
ll n,m,s,t;
int nN,re[N];
void setlen(int n){
	int bit=0;
	for (nN=1;nN<2*n;nN<<=1,++bit);
	re[0]=0;
	for (int i=1;i<nN;++i)
		re[i]=re[i>>1]>>1|(i&1)<<bit-1;
}
void clear(ll A[],int n){
	memset(A,0,sizeof(ll)*n);
}
void dft(ll A[],int flag){
	static ll w[N];
	for (int i=0;i<nN;++i)
		if (i<re[i])
			swap(A[i],A[re[i]]);
	for (int i=1;i<nN;i<<=1){
		ll wn=qpow(3,flag==1?(mo-1)/(2*i):mo-1-(mo-1)/(2*i));
		w[0]=1;
		for (int k=1;k<i;++k)
			w[k]=w[k-1]*wn%mo;
		for (int j=0;j<nN;j+=i<<1){
			ll wnk=1;
			for (int k=0;k<i;++k){
				ll x=A[j+k],y=A[j+k+i]*w[k];
				A[j+k]=(x+y)%mo;
				A[j+k+i]=(x-y+mo2)%mo;
			}
		}
	}
	if (flag==-1){
		ll invn=inv[nN];
		for (int i=0;i<nN;++i)
			A[i]=A[i]*invn%mo;
	}
}
void multi(ll c[],ll a[],ll b[],int n){
	static ll A[N],B[N];
	setlen(n);
	clear(A,nN);
	for (int i=0;i<n;++i)
		A[i]=a[i];
	dft(A,1);
	if (a!=b){
		clear(B,nN);
		for (int i=0;i<n;++i)
			B[i]=b[i];
		dft(B,1);
		for (int i=0;i<nN;++i)
			c[i]=A[i]*B[i]%mo;
	}
	else{
		for (int i=0;i<nN;++i)
			c[i]=A[i]*A[i]%mo;
	}
	dft(c,-1);
	for (int i=n;i<nN;++i)
		c[i]=0;
}
void getinv(ll B[],ll A[],int n){
	static ll t[N],t1[N];
	int nn=1;
	for (;nn<n;nn<<=1);
	clear(B,nn);
	B[0]=qpow(A[0]);
	for (int i=2;i<=nn;i<<=1){
		setlen(i);
		clear(t,nN),clear(t1,nN);
		for (int j=0;j<i;++j)
			t[j]=B[j],t1[j]=A[j];
		dft(t,1),dft(t1,1);
		for (int j=0;j<nN;++j)
			t[j]=t[j]*t[j]%mo*t1[j]%mo;
		dft(t,-1);
		for (int j=0;j<i;++j)
			B[j]=(2*B[j]-t[j]+mo)%mo;
	}
	for (int i=n;i<=nn;++i)
		B[i]=0;
}
void getln(ll B[],ll A[],int n){
	static ll A_[N],t[N];
	for (int i=1;i<n;++i)
		A_[i-1]=A[i]*i%mo;
	A_[n-1]=0;
	getinv(t,A,n);
	multi(B,A_,t,n);
	for (int i=n-1;i>=1;--i)
		B[i]=B[i-1]*inv[i]%mo;
	B[0]=0;
}
void getexp(ll B[],ll A[],int n){
	static ll t[N];
	B[0]=1;
	int m=0;
	for (;1<<m<=n;++m){
		getln(t,B,1<<m);
		t[0]=(1+A[0]-t[0]+mo)%mo;
		for (int j=1;j<1<<m;++j)
			t[j]=(A[j]-t[j]+mo)%mo;
		multi(B,B,t,1<<m+1);
	}
	getln(t,B,n);
	t[0]=(1+A[0]-t[0]+mo)%mo;
	for (int j=1;j<n;++j)
		t[j]=(A[j]-t[j]+mo)%mo;
	multi(B,B,t,n);
}
void getpow(ll A[],ll k,int n){
	static ll t[N];
	ll c=0,d=0;
	for (int i=0;i<n;++i)
		if (A[i]){
			c=A[i],d=i;
			break;
		}
	int invc=qpow(c);
	for (int i=d;i<n;++i)
		A[i-d]=A[i]*invc%mo;
	for (int i=n-d;i<n;++i)
		A[i]=0;
	getln(t,A,n);
	k%=mo;
	for (int i=0;i<n;++i)
		t[i]=t[i]*k%mo;
	getexp(A,t,n);
	c=qpow(c,k);
	d*=k;
	for (int i=n-1;i>=d;--i)
		A[i]=A[i-d]*c%mo;
	for (int i=min((ll)n-1,d-1);i>=0;--i)
		A[i]=0;
}
ll F[N],G[N];
void initC(ll F[],ll k,int n){
	ll c=1;
	F[0]=1;
	for (int i=1;i<=n;++i){
		c=c*((k-i+1)%mo)%mo*inv[i]%mo;
		F[i]=c;
	}
}
int main(){
	freopen("sum.in","r",stdin);
	freopen("sum.out","w",stdout);
	inv[0]=1,inv[1]=1;
	for (int i=2;i<=N;++i)
		inv[i]=(ll)(mo-mo/i)*inv[mo%i]%mo;
	scanf("%lld%lld%lld%lld",&s,&t,&n,&m);
//	F[0]=1,F[1]=1;
//	getpow(F,t,m-n+2);
	initC(F,t,m-n+1);
	F[0]--;
	for (int i=0;i<=m-n;++i)
		F[i]=F[i+1];
	F[m-n+1]=0;
	getpow(F,n,m-n+1);
	G[0]=1,G[1]=1;
	initC(G,s-n*t,m-n);
//	getpow(G,s-n*t,m-n+1);
	multi(F,F,G,m-n+1);
	printf("%lld\n",F[m-n]);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章