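// 多項式exp,ln,求逆板子
// Template: power-series (polynomial) exp, ln and inverse via NTT modulo 998244353.
//
// 題目是jzoj 5923 (Problem: JZOJ 5923)
// 習題:jzoj 6024 (Practice problem: JZOJ 6024)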

#include<cstdio>
#include<cstring>
#include<algorithm>
// Loop helpers: fo walks i = a..b ascending, fd walks i = b..a descending (both inclusive).
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fd(i,b,a) for(int i=b;i>=a;--i)
// NOTE: these macros hide std::max/std::min for everything below and may evaluate
// their arguments twice.
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)<(y)?(x):(y))
// Fill every byte of array `a` with `x`. The size must come from the array:
// sizeof(x) is only the size of the fill value, which filled just a few bytes.
#define mset(a,x) memset(a,x,sizeof(a))
using namespace std;
typedef long long ll;
// N bounds every buffer (including the NTT length); mo = 998244353 = 119*2^23+1,
// used with root 3 by the NTT below.
const int N=4e5+5,mo=998244353;
// Fast modular exponentiation: returns x^n mod mo, always in [0, mo).
// x may be any value (negative bases are normalised first); n must be >= 0.
// qmi(x, mo-2) is the modular inverse of x when x is not a multiple of mo.
ll qmi(ll x,ll n)
{
	ll res=1;                 // kept as ll to match the ll arithmetic below
	x%=mo;
	if(x<0) x+=mo;            // C++ % keeps the sign of x; bring the base into [0, mo)
	for(;n;n>>=1,x=x*x%mo)
		if(n&1) res=res*x%mo;
	return res;
}
// Global state shared by the polynomial routines and main().
int n;                      // size read from input in main()
ll fac[N],ifac[N],inv[N];   // i!, 1/i!, 1/i (mod mo) for i in [0, n]
ll f[N],g[N];               // f[i] = 2^(i(i-1)/2) / i!, g = ln(f) (filled in main)
int lg,len,invlen;          // current NTT: log2(length), length, length^{-1} mod mo
int bitrev[N];              // bit-reversal permutation for the current NTT length
ll W[N];                    // W[i] = root^i for the current NTT length, i in [0, len]
// In-place iterative NTT of a[0..len-1] using the tables built by nttpre().
// sig > 0: forward transform (twiddles W[k*step]).
// otherwise: conjugate twiddles W[len - k*step]; when sig == -1 the result is also
// scaled by len^{-1}, giving the inverse transform.
void dft(ll *a,int sig=1)
{
	// Reorder into bit-reversed positions so the butterflies can run in place.
	for(int i=0;i<len;++i)
		if(i<bitrev[i]) swap(a[i],a[bitrev[i]]);
	for(int span=2,step=len>>1;span<=len;span<<=1,step>>=1)
	{
		int half=span>>1;
		for(int start=0;start<len;start+=span)
			for(int k=0;k<half;++k)
			{
				ll tw=sig>0?W[k*step]:W[len-k*step];
				ll lo=a[start+k];
				ll hi=tw*a[start+k+half]%mo;
				a[start+k]=(lo+hi)%mo;
				a[start+k+half]=(lo-hi+mo)%mo;
			}
	}
	if(sig==-1)
		for(int i=0;i<len;++i) a[i]=a[i]*invlen%mo;
}
void nttpre(int len)
{
	invlen=qmi(len,mo-2);
	fo(i,0,len-1) bitrev[i]=(bitrev[i>>1]>>1)|((i&1)<<(lg-1));
	W[0]=1,W[1]=qmi(3,(mo-1)/len);
	fo(i,2,len) W[i]=W[1]*W[i-1]%mo;
}
// Power-series inverse by Newton iteration:  b(x) = 1 / a(x)  (mod x^n).
// Requires a[0] != 0 (the base case inverts a[0] with qmi).
// Solves to precision ceil(n/2) recursively, then refines with b <- b*(2 - a*b).
// Side effects: overwrites the global NTT state (lg, len, invlen, bitrev, W) and
// zeroes b[n..len-1] for the transform length used at this level.
void polyinv(ll *b,ll *a,int n)
{
	if(n==1)
	{
		b[0]=qmi(a[0],mo-2);
		return;
	}
	polyinv(b,a,(n+1)>>1);
	static ll ta[N],tb[N],prod[N];
	// Transform length: smallest power of two strictly greater than 2n.
	for(lg=0,len=1;len<=2*n;len<<=1) ++lg;
	for(int i=0;i<len;++i)
	{
		ta[i]=i<n?a[i]:0;
		tb[i]=i<n?b[i]:0;
	}
	nttpre(len);
	dft(ta);
	dft(tb);
	for(int i=0;i<len;++i) prod[i]=ta[i]*tb[i]%mo*tb[i]%mo;   // pointwise a*b*b
	dft(prod,-1);
	for(int i=0;i<n;++i) b[i]=(2*b[i]-prod[i]+mo)%mo;
	for(int i=n;i<len;++i) b[i]=0;
}
// Power-series logarithm:  b(x) = ln(a(x))  (mod x^n), computed as the integral of
// a'(x) / a(x). The constant term is forced to 0, which is correct when a[0] == 1.
// Uses the global table inv[1..n-1] and overwrites the global NTT state.
void polyln(ll *b,ll *a,int n)
{
	static ll der[N],rec[N];
	// der = a'(x), truncated to n coefficients.
	for(int i=0;i+1<n;++i) der[i]=a[i+1]*(i+1)%mo;
	der[n-1]=0;
	polyinv(rec,a,n);                  // rec = 1/a mod x^n
	for(lg=0,len=1;len<=2*n;len<<=1) ++lg;
	for(int i=n;i<len;++i) der[i]=rec[i]=0;
	nttpre(len);
	dft(der);
	dft(rec);
	for(int i=0;i<len;++i) der[i]=der[i]*rec[i]%mo;
	dft(der,-1);
	// Integrate term by term: coefficient i of ln(a) is der[i-1] / i.
	b[0]=0;
	for(int i=1;i<n;++i) b[i]=der[i-1]*inv[i]%mo;
}
// Power-series exponential:  b(x) = exp(a(x))  (mod x^n); the base case b[0] = 1
// is correct when a[0] == 0.
// Newton iteration: with b0 = exp(a) mod x^ceil(n/2),
//   b = b0 * (1 - ln(b0) + a)  (mod x^n).
// The caller's b must hold at least `len` entries, where len is the smallest power of
// two > 2n. The padding loops stop at len-1 (previously they also wrote index len,
// one past the transform). Overwrites the global NTT state.
void polyexp(ll *b,ll *a,int n)
{
	if(n==1) {b[0]=1;return;}
	int half=(n+1)>>1;
	polyexp(b,a,half);
	// Only the first `half` coefficients are valid; clear the rest before taking ln.
	fo(i,half,n-1) b[i]=0;
	static ll c[N];
	polyln(c,b,n);
	// c = 1 - ln(b0) + a  (mod x^n)
	fo(i,0,n-1) c[i]=(a[i]-c[i]+mo)%mo;
	c[0]=(c[0]+1)%mo;

	lg=0,len=1;
	while(len<=n+n) len<<=1,++lg;
	nttpre(len);
	// Zero-pad both factors to the transform length.
	fo(i,n,len-1) b[i]=c[i]=0;
	dft(b);
	dft(c);
	fo(i,0,len-1) b[i]=b[i]*c[i]%mo;
	dft(b,-1);
	// Truncate the product back to mod x^n.
	fo(i,n,len-1) b[i]=0;
}
// Returns n! * [x^n] exp( sum_{i=1..k} g[i] x^i ), with g = ln(f) computed in main().
// f[i] = 2^(i(i-1)/2)/i! is the EGF of labelled graphs, so g is the EGF of connected
// ones. This therefore counts labelled graphs on n vertices whose connected
// components all have at most k vertices.
// Only coefficients up to x^n matter, so k is clamped to n. A larger k changes nothing
// but would index a[] and g[] past n, and past N for very large k.
ll calc(int k)
{
	static ll a[N],b[N];
	if(k>n) k=n;
	a[0]=0;
	fo(i,1,k) a[i]=g[i];
	fo(i,k+1,n) a[i]=0;
	polyexp(b,a,n+1);
	return b[n]*fac[n]%mo;
}
int main()
{
	freopen("bomb.in","r",stdin);
	freopen("bomb.out","w",stdout);
	int k;
	scanf("%d %d",&n,&k);
	fac[0]=ifac[0]=1;
	fo(i,1,n) fac[i]=fac[i-1]*i%mo,inv[i]=(i==1)?1:(mo-mo/i)*inv[mo%i]%mo,ifac[i]=ifac[i-1]*inv[i]%mo;
	f[0]=1;
	fo(i,1,n) f[i]=qmi(2,1ll*i*(i-1)/2)*ifac[i]%mo;
	polyln(g,f,n+1);
	printf("%lld",(calc(k)-calc(k-1)+mo)%mo);
	return 0;
}
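// ---- Leftover text from the blog page this code was copied from (not part of the program) ----
// 發表評論 (Post a comment)
// 所有評論 (All comments)
// 還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
//   (No comments yet. Want to be the first? Type in the comment box above and click Publish.)
// 相關文章 (Related articles)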