20200501省選模擬賽 a(概率生成函數+推式子)

 

題解

好題,但是這個解法適用範圍比較窄,也沒有多大的用處    ____by   Freopen

我們先把p數組寫成概率生成函數的形式,設

P(x)=\sum_{i=0}^kp_i*x^i

我們發現x^i的係數表示我們走一步到位置 i 的概率是多少

那麼P^2(x)中x^i的係數就表示我們走2步到位置 i 的概率是多少

我們發現答案的概率生成函數(設爲Q(x))就是P^n(x),設

Q(x)=\sum_{i=0}^{nk}q_i*x^i=P^n(x)

最後我們只需要保留一下前面的0~t項(這裏我們替換一下變量,原題用的x與生成函數用的x重複了)

答案就是

\sum_{i=0}^ti*q_i+(1-\sum_{i=0}^tq_i)*t

所以我們有了一個暴力的做法——多項式快速冪O(n*t)

可以用NTT優化一下做到O(t*logn)

然而只有40分(話說這部分分也太少了吧)

 

接下來我們就進入開掛模式(也許這是一個數學中的常用套路吧。。)

(P^{n+1}(x))'=(n+1)P^n(x)*P'(x)(求導的鏈式法則)

(P^{n+1}(x))'=P(x)*(P^n(x))'+P'(x)*P^n(x)(乘法的求導)

於是我們就有了

n*P^n(x)*P'(x)=P(x)*(P^n(x)')

n*Q(x)*P'(x)=Q'(x)*P(x)

然後把它們展開

n\sum_{i=0}^{nk}q_i*x^i*\sum_{j=0}^{k-1}(j+1)p_{j+1}*x^j=\sum_{i=0}^{nk-1}(i+1)q_{i+1}*x^i*\sum_{j=0}^kp_j*x^j

由待定係數法可知,左右同次的x的係數是相等的

設s=i+j,則x^s的係數可以表示爲

n\sum_{j=0}^{k-1}(j+1)*p_{j+1}*q_{s-j+1}=\sum_{j=0}^k(s-j+1)*q_{s-j+1}*p_j

我們把右邊j=0的那一項單獨提出來

n\sum_{j=0}^{k-1}(j+1)*p_{j+1}*q_{s-j+1}=(s+1)*q_{s+1}*p_0+\sum_{j=1}^k(s-j+1)*q_{s-j+1}*p_j

移一下項

p0*(s+1)*q_{s+1}=

n\sum_{j=0}^{k-1}(j+1)*p_{j+1}*q_{s-j+1}-\sum_{j=1}^k(s-j+1)*q_{s-j+1}*p_j

換一下求和的範圍

\sum_{j=1}^{k}n*j*p_j*q_{s-j+1}-(s-j+1)*q_{s-j+1}*p_j

最終得到:

p0*(s+1)*q_{s+1}=\sum_{j=1}^{k}(n*j*-(s-j+1))*q_{s-j+1}*p_j

所以

q_{s+1}=\frac{\sum_{j=1}^k(n*j-(s-j+1))*q_{s-j+1}*p_j}{p0*(s+1)}

令s=s+1

q_s=\frac{\sum_{j=1}^k(n*j-(s-j))*q_{s-j}*p_j}{p0*s}

這樣就可以O(k*t)計算q數組了

代碼:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 100000005
const int mod=998244353;
int inv[N],f[N],p[105];
int ksm(int x,int y)
{
	int ret=1;
	while(y){
		if(y&1)ret=1ll*ret*x%mod;
		y>>=1;x=1ll*x*x%mod;
	}
	return ret;
}
int main()
{
	int n,k,t,i,j,s=0,ans=0;
	int ip0;
	scanf("%d%d%d",&n,&k,&t);
	for(i=0;i<=k;i++){scanf("%d",&p[i]);s+=p[i];}
	s=ksm(s,mod-2);
	for(i=0;i<=k;i++)p[i]=1ll*p[i]*s%mod;
	ip0=ksm(p[0],mod-2);
	f[0]=ksm(p[0],n);
	ans=-1ll*t*f[0]%mod;
	for(i=1;i<=t;i++){
		if(i==1)inv[i]=1;
		else inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
		int sum=0;
		for(j=1;j<=k&&j<=i;j++)
			sum=(1ll*sum+1ll*(1ll*p[j]*j%mod*n%mod-1ll*p[j]*(i-j)%mod+1ll*mod)%mod*f[i-j])%mod;
		f[i]=1ll*sum*ip0%mod*inv[i]%mod;
		ans=(1ll*ans+1ll*f[i]*(i-t))%mod;
	}
	printf("%d\n",((1ll*ans+1ll*t)%mod+mod)%mod);
}

 

 

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章