P4705 玩遊戲

有兩個非負整數序列,我們稱其爲\(a_1\cdots a_n\)\(b_1\cdots b_m\)。每次遊戲中玩家會從\(a\)序列和\(b\)序列中分別隨機地抽取一個數,假設抽出的數爲\(a_i,b_j\),則定義這次遊戲的\(k\)次價值爲:\((a_i+b_j)^k\)。要求對於每個\(k\in [1,t]\)\(k\)次價值的期望,對\(998244353\)取模。\(n,m,t\leq 10^5,a_i,b_j\leq 998244352\)


顯然有答案=\(\frac{\sum_{i=1}^n\sum_{j=1}^m(a_i+b_j)^k}{nm}\)。我們只關注分子的式子,把它用二項式定理展開一下就有:

\[\sum_{i=1}^{n}\sum_{j=1}^m\sum_{p=0}^k{k\choose p}a_i^pb_j^{k-p} \\=k!\sum_{p=0}^k\sum_{i=1}^n\frac{a_i^p}{p!}\sum_{j=1}^m\frac{b_j^{k-p}}{(k-p)!} \]

觀察到後面的兩個求和式中分子的指數和分母的階乘是一樣的,那麼可以寫成指數型生成函數的形式,下面以第一個求和式爲例:

\[A(x)=\sum_{p=0}(\sum_{i=1}^na_i^p)\frac{x^p}{p!} \]

那麼答案式子就可以寫成\(k![x^k]A(x)B(x)\)。現在問題就是要計算\(A(x)\)\(B(x)\)的係數。

仍然寫出係數的生成函數,以\(A(x)\)爲例,就有:

\[F_A(x)=\sum_{p=0}(\sum_{i=1}^na_i^p)x^p = \sum_{i=1}^n\sum_{p=0}(a_ix)^p = \sum_{i=1}^n\frac{1}{1-a_ix} \]

怎麼計算最後這個多項式呢?普通的做法就是通分,分母的多項式相乘,分子與分母交叉相乘再相加。但如果從左往右加過去開銷就很大。我們可以用分治的思想,每次把問題折半,先算左邊和右邊的\(\frac{n}{2}\)規模的問題,然後左右兩邊通分加起來。複雜度就是\(O(nlog^2n)\)的。

或者有一個常數更小的做法。注意到\((\ln{(1-a_ix)})'=-\frac{a_i}{1-a_ix}=\frac{1}{1-a_ix}-1\),那麼設\(G_A(x)=\sum_{i=1}^n-\frac{a_i}{1-a_ix}\),就有\(F_A(x)=-xG_A(x)+n\),於是:

\[G_A(x)=\sum_{i=1}^n(\ln{(1-a_ix)})'=(\sum_{i=1}^n\ln{(1-a_ix)})'=(\ln{\prod_{i=1}^n(1-a_ix)})' \]

分治NTT即可。

#include<bits/stdc++.h>
#define rg register
#define il inline
#define cn const
#define gc getchar()
#define fp(i,a,b) for(rg int i=(a),ed=(b);i<=ed;++i)
#define fb(i,a,b) for(rg int i=(a),ed=(b);i>=ed;--i)
#define go(u) for(rg int i=head[u];~i;i=e[i].nxt)
using namespace std;
typedef cn int cint;
typedef long long LL;
il int rd(){
	rg int x(0),f(1); rg char c(gc);
	while(c<'0'||'9'<c){if(c=='-')f=-1;c=gc;}
	while('0'<=c&&c<='9')x=(x<<1)+(x<<3)+(c^48),c=gc;
	return x*f;
}
cint maxn = 1e5+10, mod = 998244353, G = 3, invG = (mod+1)/3, inf = 0x3f3f3f3f;
int n, m, t, a[maxn], b[maxn], fac[maxn], ifac[maxn];
il void sum(int &a, int b){if((a += b) >= mod)a -= mod;}
il int fpow(int a, int b, int ans = 1){
	for(; b; b >>= 1, a = 1ll*a*a%mod)if(b&1)ans = 1ll*ans*a%mod;
	return ans;
}
int rev[maxn<<4];
struct poly{
	vector<int> ary;
	int len;
};
void print(cn poly &a){fp(i, 0, a.len-1)printf("%d ", a.ary[i]);puts("");}
il void crt(poly &po, cint &len){po.ary.resize(len), po.len = len;}
il void ntt(poly &a, cint &f){
	fp(i, 0, a.len-1)if(i < rev[i])swap(a.ary[i], a.ary[rev[i]]);
	for(rg int md = 1; md < a.len; md <<= 1){
		rg int len = md<<1, Gn=fpow(f ? G : invG, (mod-1)/len);
		for(rg int l = 0; l < a.len; l += len){
			rg int Pow = 1;
			for(rg int nw = 0; nw < md; ++nw, Pow = 1ll*Pow*Gn%mod){
				rg int x = a.ary[l+nw], y = 1ll*Pow*a.ary[l+nw+md]%mod;
				a.ary[l+nw] = (x+y)%mod, a.ary[l+nw+md] = (x-y+mod)%mod;
			}
		}
	}
	if(!f){
		rg int inv = fpow(a.len, mod-2);
		fp(i, 0, a.len - 1)a.ary[i] = 1ll*a.ary[i]*inv%mod;
	}
}
il poly mul(cn poly &po1, cn poly &po2, int mx = inf){
	poly a, b, c;
	rg int len = po1.len+po2.len-1, lim = 1, hst = 0;
	while(lim < len)lim <<= 1, ++hst;
	fp(i, 1, lim-1)rev[i] = (rev[i>>1]>>1) | ((i&1)<<hst-1);
	crt(a, lim), crt(b, lim);
	fp(i, 0, po1.len-1)a.ary[i] = po1.ary[i];
	fp(i, 0, po2.len-1)b.ary[i] = po2.ary[i];
	ntt(a, 1), ntt(b, 1);
	fp(i, 0, lim-1)a.ary[i] = 1ll*a.ary[i]*b.ary[i]%mod;
	ntt(a, 0);
	crt(c, min(mx, len));
	fp(i, 0, c.len-1)c.ary[i] = a.ary[i];
	return c;
}
poly divntt(int *a, int l, int r){
	poly res;
	if(l == r){
		crt(res, 2), res.ary[0] = 1, res.ary[1] = mod-a[l];
		return res;
	}
	int md = l+r>>1;
	res = mul(divntt(a, l, md), divntt(a, md+1, r));
	return res;
}
poly getinv(cn poly &a, int n){
	poly res;
	if(n == 1){
		crt(res, 1), res.ary[0] = fpow(a.ary[0], mod-2);
		return res;
	}
	poly f = getinv(a, n+1>>1), g;
	crt(g, n);
	fp(i, 0, n-1)if(i<a.len)g.ary[i] = a.ary[i];else g.ary[i] = 0;
	g = mul(g, f, n);
	fp(i, 0, g.len-1)g.ary[i] = mod-g.ary[i];
	sum(g.ary[0], 2), res = mul(f, g, n);
	return res;
}
il poly der(cn poly &a){
	poly res;
	crt(res, a.len-1);
	fp(i, 1, a.len-1)res.ary[i-1] = 1ll*i*a.ary[i]%mod;
	return res;
}
int main(){
	n = rd(), m = rd();
	fp(i, 1, n)a[i] = rd();
	fp(i, 1, m)b[i] = rd();
	t = rd();
	
	poly GA = divntt(a, 1, n), GB = divntt(b, 1, m);
	poly invGA = getinv(GA, max(n, t)+1), invGB = getinv(GB, max(m, t)+1);
	GA = mul(der(GA), invGA), GB = mul(der(GB), invGB);
	poly FA, FB;
	crt(FA, GA.len+1), crt(FB, GB.len+1);
	fp(i, 0, GA.len-1)FA.ary[i+1] = mod-GA.ary[i];
	fp(i, 0, GB.len-1)FB.ary[i+1] = mod-GB.ary[i];
	FA.ary[0] = n, FB.ary[0] = m;
	
	fac[0] = 1; fp(i, 1, t)fac[i] = 1ll*fac[i-1]*i%mod;
	ifac[t] = fpow(fac[t], mod-2); fb(i, t, 1)ifac[i-1] = 1ll*ifac[i]*i%mod;
	fp(i, 0, FA.len-1)FA.ary[i] = 1ll*FA.ary[i]*ifac[i]%mod;
	fp(i, 0, FB.len-1)FB.ary[i] = 1ll*FB.ary[i]*ifac[i]%mod;
	poly ans = mul(FA, FB);
	rg int tot = 1ll*fpow(n, mod-2)*fpow(m, mod-2)%mod;
	fp(i, 1, t)printf("%lld\n", 1ll*fac[i]*ans.ary[i]%mod*tot%mod);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章