Problem
看題後:
- boshi:這是一道簡單題
- 隊長:這題好像不難,感覺和獵人殺有點像
- 我:
Solution
感覺自己越來越菜了,再這樣下去,要是正式考試送溫暖豈不是連溫暖都拿不到了。。
一臉min-max反演的樣子,由於每個鴿子都等價,枚舉子集大小 即可
其中 來源於平均每 步纔會有一粒餵給選中的鴿子。 表示的是給 只鴿子餵食,有一個鴿子大於等於 時停止的期望步數。
枚舉餵給其他鴿子的玉米粒數量,概率通過方案數來算
其中 表示給 只鴿子喂 粒,且每隻都小於 的方案數。這個可以用生成函數算
時間複雜度
Code
#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long ll;
const int maxn=70010,mod=998244353,G=3;
template <typename Tp> inline int getmin(Tp &x,Tp y){return y<x?x=y,1:0;}
template <typename Tp> inline int getmax(Tp &x,Tp y){return y>x?x=y,1:0;}
template <typename Tp> inline void read(Tp &x)
{
x=0;int f=0;char ch=getchar();
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') f=1,ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
if(f) x=-x;
}
int n,k,N,l,ans,fac[maxn],inv[maxn],f[55],g[55][maxn],rev[maxn];
int pls(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x<y?x-y+mod:x-y;}
int C(int n,int m){return m>n?0:(ll)fac[n]*inv[m]%mod*inv[n-m]%mod;}
int power(int x,int y)
{
int res=1;
for(;y;y>>=1,x=(ll)x*x%mod)
if(y&1)
res=(ll)res*x%mod;
return res;
}
void NTT(int *a,int f)
{
for(int i=1;i<N;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<N;i<<=1)
{
int gn=power(G,(mod-1)/(i<<1));
for(int j=0;j<N;j+=(i<<1))
{
int g=1,x,y;
for(int k=0;k<i;++k,g=(ll)g*gn%mod)
{
x=a[j+k];y=(ll)g*a[j+k+i]%mod;
a[j+k]=pls(x,y);a[j+k+i]=dec(x,y);
}
}
}
if(f==-1)
{
int iv=power(N,mod-2);reverse(a+1,a+N);
for(int i=0;i<N;i++) a[i]=(ll)a[i]*iv%mod;
}
}
void init(int N)
{
fac[0]=1;
for(int i=1;i<=N;i++) fac[i]=(ll)fac[i-1]*i%mod;
inv[N]=power(fac[N],mod-2);
for(int i=N-1;~i;i--) inv[i]=(ll)inv[i+1]*(i+1)%mod;
}
int main()
{
read(n);read(k);
init(n*k);
for(N=1,l=0;N<=(n*k);N<<=1) ++l;
for(int i=1;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
for(int i=0;i<k;i++) g[1][i]=inv[i];
NTT(g[1],1);g[0][0]=1;
for(int i=2;i<n;i++)
for(int j=0;j<N;j++)
g[i][j]=(ll)g[i-1][j]*g[1][j]%mod;
for(int i=1;i<=1||i<n;i++) NTT(g[i],-1);
for(int i=1;i<=n;i++)
{
int inv=power(i,mod-2);int tmp=power(inv,k-1);
for(int j=0;j<N;j++,tmp=(ll)tmp*inv%mod)
f[i]=(f[i]+(ll)(j+k)*C(j+k-1,j)%mod*g[i-1][j]%mod*fac[j]%mod*tmp)%mod;
}
for(int i=1;i<=n;i++)
{
if(i&1) ans=(ans+(ll)C(n,i)*n%mod*power(i,mod-2)%mod*f[i])%mod;
else ans=dec(ans,(ll)C(n,i)*n%mod*power(i,mod-2)%mod*f[i]%mod)%mod;
}
printf("%d\n",ans);
return 0;
}