cf755G. PolandBall and Many Other Balls

一個解法

一句話題意:給出1~n的序列,一個組的定義是1或2個相鄰的數字,求每個數字最多屬於1個組、共1~k個組分別的答案,對998244353取膜。//我語文差你來打我啊

有一個SB的DP算法:dp[i][j]=dp[i-1][j]+dp[i-1][j-1]+dp[i-2][j-1]。其中dp[i][j]表示前i個j組的方案。//lych:這還能不用FFT(NTT)噠

把dp[i]看做多項式,dp[a+b]=dp[a]*dp[b]+dp[a-1]*dp[b-1]*x(霧

這個可以遞歸求解,即|a-b|<=1,和快速冪一樣達到O(logn)。

每層維護2個多項式,dp[i],dp[i+1],手算dp[i+2],就能得到dp[i*2],dp[i*2+1]。沒了。

複雜度O(k log k log n),反正能過。對於FFT的理解之後幾天再分析吧。。

#include<iostream>
#include<cstdio>
#include<algorithm>
#define ll long long
#define P 998244353
#define N 131073
using namespace std;
int n,k,tn,tl,r[N],w[2][N],rn;
int a[N],b[N],c[N],d[N],e[N],f[N],d1[N],e1[N];
int pow(int a,int b,int c)
{
	int ans=1;
	for (;b;a=(ll)a*a%c,b>>=1)
		if (b&1) ans=(ll)ans*a%c;
	return ans;
}

void pre(int x)
{
	tl=0;tn=1;
	while(tn<x)tn<<=1,tl++;
	tn<<=1;tl++;
	int W=pow(3,(P-1)/tn,P);
	w[0][0]=w[1][0]=1;
	for (int i=1;i<tn;i++)
		w[0][i]=(ll)w[0][i-1]*W%P;
	for (int i=1;i<tn;i++)
		w[1][i]=w[0][tn-i];
	for (int i=1;i<tn;i++)
		r[i]=(r[i>>1]>>1)|((i&1)<<(tl-1));
	rn=pow(tn,P-2,P);
}

void dft(int *a,int f)
{
	for (int i=0;i<tn;i++)
		if(i<r[i]) swap(a[i],a[r[i]]);
	for(int i=1;i<tn;i<<=1)
	for(int j=0,t=tn/(i<<1);j<tn;j+=i<<1)
	for(int k=0,l=0;k<i;k++,l+=t)
	{
		int x=(ll)w[f][l]*a[j+k+i]%P;
		int y=a[j+k];
		a[j+k]=(y+x)%P;
		a[j+k+i]=(y+P-x)%P;
	}
	if(f)
		for (int i=0;i<tn;i++)
			a[i]=(ll)a[i]*rn%P;
}
void Get(int a[],int b[],int c[])
{
	c[0]=1;
	for (int i=1;i<=k;i++)
		c[i]=((ll)b[i]+b[i-1]+a[i-1])%P;
	for (int i=k+1;i<tn;i++)
		c[i]=0;
}
void work(int x)
{
	if (x==0)
	{
		a[0]=1;for (int i=1;i<tn;i++)a[i]=0;
		b[0]=1;b[1]=1;for (int i=2;i<tn;i++)b[i]=0;
		return;
	}
	if (x==1)
	{
		a[0]=1;a[1]=1;for (int i=2;i<tn;i++)a[i]=0;
		b[0]=1;b[1]=3;b[2]=1;for (int i=3;i<tn;i++)b[i]=0;
		return;
	}
	work(x/2-1);
	Get(a,b,c);
	for (int i=k+1;i<tn;i++)
		a[i]=b[i]=c[i]=0;
	dft(a,0);
	dft(b,0);
	dft(c,0);
	for (int i=0;i<tn;i++)
	{
		d[i]=(ll)b[i]*b[i]%P;
		d1[i]=(ll)a[i]*a[i]%P;
		e[i]=(ll)b[i]*c[i]%P;
		e1[i]=(ll)a[i]*b[i]%P;
	}
	dft(d,1);
	dft(e,1);
	dft(d1,1);
	dft(e1,1);
	for (int i=1;i<tn;i++)
		d[i]=(d[i]+d1[i-1])%P,e[i]=(e[i]+e1[i-1])%P;
	if (x&1)
	{
		Get(d,e,f);
		for (int i=0;i<tn;i++)
		{
			a[i]=i<=k?e[i]:0;
			b[i]=i<=k?f[i]:0;
		}
	}
	else
	{
		for (int i=0;i<tn;i++)
		{
			a[i]=i<=k?d[i]:0;
			b[i]=i<=k?e[i]:0;
		}
	}
}
int main()
{
	scanf("%d%d",&n,&k);
	pre(max(k+1,4));
	work(n);
	for (int i=1;i<=k;i++)
		printf("%d ",(a[i]+P)%P);
	puts("");
}




發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章