CF755G PolandBall and Many Other Balls

Address

CF755G

Algorithm 1

設 $f_{i,j}$ 表示考慮了前 $i$ 個球選出 $j$ 組的方案數,轉移顯然爲:
$$ f_{i,j} = f_{i - 1, j} + f_{i - 1, j - 1} + f_{i - 2, j - 1} $$
考慮由 $f_{a,*},f_{b,*}$ 怎麼得到 $f_{a+b,*}$,只需要討論是否恰好有包含兩個球的一組從中間跨過,即:
$$ f_{a+b,j} = \sum \limits_{r = 0}^{j} f_{a,r} f_{b,j - r} + \sum \limits_{r = 0}^{j - 1} f_{a - 1,r} f_{b - 1, j - r - 1} $$
同理我們有:
$$ \begin{aligned} f_{a + b - 1, j} &= \sum \limits_{r = 0}^{j} f_{a,r} f_{b - 1,j - r} + \sum \limits_{r = 0}^{j - 1} f_{a - 1, r} f_{b - 2,j - r - 1} \\ &= \sum \limits_{r = 0}^{j} f_{a,r} f_{b - 1,j - r} + \sum \limits_{r = 0}^{j - 1} f_{a - 1, r} \left( f_{b,j - r} - f_{b - 1, j - r} - f_{b - 1, j - r - 1} \right) \end{aligned} $$
顯然可以用 NTT 優化,因此求出 $n$ 的二進制拆分逐項合併即可,時間複雜度 $\mathcal O(k \log k \log n)$。

Algorithm 2

設 $f_{i,j}$ 的生成函數爲:
$$ F_i(x) = \sum \limits_{j = 0}^{k} f_{i,j} x^j $$
顯然有如下遞推式:
$$ F_i(x) = (1 + x) F_{i - 1}(x) + x F_{i - 2}(x) $$
列出特徵方程:
$$ y^2 - (1 + x) y - x = 0 $$
解得:
$$ \begin{cases} y_1 = \dfrac{1 + x + \sqrt{1 + 6x + x^2}}{2} \\ y_2 = \dfrac{1 + x - \sqrt{1 + 6x + x^2}}{2} \end{cases} $$
因此 Fn(x)F_n(x) 可以表示爲:
$$ F_n(x) = a y_1^n + b y_2^n $$
由 $F_0(x) = 1, F_1(x) = 1 + x$ 解得:
$$ \begin{cases} a = \cfrac{y_1}{\sqrt{1 + 6x + x^2}} \\ b = \cfrac{-y_2}{\sqrt{1 + 6x + x^2}} \end{cases} $$
代回原式即爲:
$$ F_n(x) = \dfrac{y_1^{n + 1} - y_2^{n + 1}}{\sqrt{1 + 6x + x^2}} $$
因爲 $y_2$ 的常數項爲 $0$,所以 $y_2^{n + 1}$ 的最低次項次數至少爲 $n + 1$:當 $n \ge k$ 時 $y_2^{n + 1} \equiv 0 \pmod{x^{k + 1}}$,可以忽略;當 $n < k$ 時它只影響次數大於 $n$ 的係數,而這些位置的答案本來就是 $0$,同樣可以忽略。

最後的答案爲:
$$ F_n(x) = \cfrac{1}{\sqrt{1 + 6x + x^2}} \exp\left((n + 1) \ln \cfrac{1 + x + \sqrt{1 + 6x + x^2}}{2}\right) $$
對於求 $(1 + 6x + x^2)$ 這樣只有常數個非零項的多項式在模 $x^{k + 1}$ 意義下的 $w$ 次冪($w$ 爲任意有理數),實際上我們有 $\mathcal O(k)$ 的做法。

記該多項式爲 $G(x)$,我們所求即爲:
$$ H(x) = G(x)^w $$
對兩邊同時求導,
$$ \begin{aligned} H'(x) &= w G(x)^{w - 1} G'(x) \\ H'(x) &= w \dfrac{H(x)}{G(x)} G'(x) \\ G(x) H'(x) &= w H(x) G'(x) \end{aligned} $$
僅針對等式兩邊次數爲 $m$ 的項前面的係數列式,有:
$$ \begin{aligned} \sum \limits_{i = 0}^{m} (i + 1) h_{i + 1} g_{m - i} &= w \sum \limits_{i = 0}^{m} (i + 1) g_{i + 1} h_{m - i} \\ (m + 1) h_{m + 1} g_{0} + \sum \limits_{i = 0}^{m - 1} (i + 1) h_{i + 1} g_{m - i} &= w \sum \limits_{i = 0}^{m} (i + 1) g_{i + 1} h_{m - i} \\ h_{m + 1} &= \dfrac{w \sum \limits_{i = 0}^{m} (i + 1) g_{i + 1} h_{m - i} - \sum \limits_{i = 0}^{m - 1} (i + 1) h_{i + 1} g_{m - i}}{(m + 1) g_0} \end{aligned} $$
因爲 $G(x)$ 中不爲 $0$ 的項只有 $\mathcal O(1)$ 個且 $h_0$ 已知,我們從小到大枚舉 $m$,就可以每次 $\mathcal O(1)$ 得到 $h_{m + 1}$。

$n$ 其實可以出得很大,時間複雜度 $\mathcal O(k \log k)$。

Code

#include <bits/stdc++.h>

/**
 * Reads the next unsigned decimal integer from stdin into `res`.
 *
 * Every character that is not a digit is skipped before the number starts
 * (this includes a leading '-', so negative inputs are not supported).
 * The character that terminates the number is consumed as well.
 *
 * If the end of input is reached before any digit is seen, `res` is set to 0
 * and the function returns instead of looping forever.
 *
 * @param res  receives the parsed value.
 */
template <class T>
inline void read(T &res)
{
	// getchar() returns an int: keeping it as int lets EOF be distinguished
	// from real characters and keeps the isdigit() argument in its valid range.
	int ch = getchar();
	while (ch != EOF && !isdigit(ch))
		ch = getchar();

	res = 0;
	while (isdigit(ch)) // isdigit(EOF) is false, so EOF ends the number too
	{
		res = res * 10 + (ch - 48);
		ch = getchar();
	}
}

/**
 * Writes `x` to stdout in decimal, most significant digit first, with no
 * separator or newline.
 *
 * Any value that is not greater than 9 (which includes negative values)
 * produces exactly one character, `x % 10 + 48`.
 *
 * @param x  value to print.
 */
template <class T>
inline void put(T x)
{
	// Collect the digits least-significant first, then emit them in reverse.
	char digits[48];
	int len = 0;
	digits[len++] = static_cast<char>(x % 10 + 48);
	while (x > 9)
	{
		x /= 10;
		digits[len++] = static_cast<char>(x % 10 + 48);
	}
	while (len > 0)
		putchar(digits[--len]);
}

typedef long long ll;
// Capacity of every global work array below.
const int N = 12e5 + 5;
// NTT-friendly prime modulus; the NTT routine uses 3 as its primitive root.
const int mod = 998244353;
// Modular inverse of 2: 2 * (mod + 1) / 2 = mod + 1 ≡ 1 (mod is odd).
const int inv2 = (mod + 1) / 2;
// Modular inverse of 3: mod + 1 is divisible by 3, so 3 * inv3 = mod + 1 ≡ 1.
const int inv3 = (mod + 1) / 3;

// First input value (read in main).
ll n;
// K: second input value, the highest coefficient index that is computed/printed.
// s: not referenced anywhere in the visible code.
int K; char s[N];
// inv:   table of modular inverses, filled for 1..K in main.
// exp_a: running Newton iterate of poly_exp.
// rev:   bit-reversal permutation consumed by NTT (rebuilt by each caller).
// tw:    twiddle factors of the NTT layer currently being processed.
int inv[N], exp_a[N], rev[N], tw[N];
// _a:    holds 1 + 6x + x^2 in main.
// _b:    accumulator for the schoolbook products in poly_mul / poly_exp.
// a, b:  the two working polynomials of main.
// c:     scratch copy of the input inside poly_inv.
// d:     scratch copy of the input inside poly_ln.
// ln_b:  scratch buffer of poly_exp.
// inv_a: running Newton iterate of poly_inv.
int _a[N], _b[N], a[N], b[N], c[N], d[N], ln_b[N], inv_a[N];

/**
 * Binary exponentiation: returns x^k modulo `mod`.
 *
 * @param x  base; products are reduced modulo `mod` after every step.
 * @param k  exponent; the loop runs until k has been shifted down to 0,
 *           so it is expected to be non-negative.
 * @return   x^k mod `mod` (1 when k == 0).
 */
inline int quick_pow(int x, int k)
{
	int acc = 1;
	for (; k; k >>= 1)
	{
		// Multiply in the current power of x whenever the low bit is set.
		if (k & 1)
			acc = 1ll * acc * x % mod;
		x = 1ll * x * x % mod;
	}
	return acc;
}

/**
 * In-place modular addition: x += y, followed by a single conditional
 * subtraction of `mod` when the sum reached `mod` or more.
 */
inline void add(int &x, int y)
{
	x += y;
	if (x >= mod)
		x -= mod;
}

/**
 * In-place modular subtraction: x -= y, followed by a single conditional
 * addition of `mod` when the difference went negative.
 */
inline void dec(int &x, int y)
{
	x -= y;
	if (x < 0)
		x += mod;
}

inline void NTT(int *f, int fm, int opt)
{
	int g = opt == 1 ? 3 : inv3;
	for (int i = 0; i < fm; ++i)
		if (i < rev[i])
			std::swap(f[i], f[rev[i]]);
	for (int k = 1; k < fm; k <<= 1)
	{
		int w = quick_pow(g, (mod - 1) / (k << 1));
		tw[0] = 1;
		for (int i = 1; i < k; ++i)
			tw[i] = 1ll * tw[i - 1] * w % mod;
		for (int i = 0; i < fm; i += k << 1)
			for (int j = 0, *f1 = f + i, *f2 = f + i + k; j < k; ++j, ++f1, ++f2)
			{	
				int u = *f1, 
					v = 1ll * tw[j] * (*f2) % mod;
				*f1 = *f2 = u;
				add(*f1, v);
				dec(*f2, v);
			}
	}
	if (opt == -1)
	{
		for (int i = 0, inv = quick_pow(fm, mod - 2); i < fm; ++i)
			f[i] = 1ll * f[i] * inv % mod;
	}
}

/**
 * Multiplies two polynomials in place: a <- a * b (mod `mod`).
 *
 * @param a   first factor, indices 0..am; receives the product (indices
 *            0..am+bm).
 * @param am  highest index of `a`.
 * @param b   second factor, indices 0..bm; zeroed on return (indices 0..bm in
 *            the schoolbook branch, the whole transform length in the NTT
 *            branch).
 * @param bm  highest index of `b`.
 *
 * Products with am + bm <= 256 use schoolbook multiplication through the
 * global accumulator `_b`, which has to be all-zero on entry and is left
 * all-zero again. Larger products go through NTT, which reads both arrays up
 * to the transform length, so they must be zero beyond am / bm.
 */
inline void poly_mul(int *a, int am, int *b, int bm)
{
	const int tot = am + bm;

	if (tot > 256)
	{
		// Smallest power of two strictly greater than tot; `top` ends up as
		// the index of the highest bit used by the bit-reversal table.
		int fm = 1, top = -1;
		while (fm <= tot)
		{
			fm <<= 1;
			++top;
		}
		for (int i = 1; i < fm; ++i)
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << top);

		NTT(a, fm, 1);
		NTT(b, fm, 1);
		for (int i = 0; i < fm; ++i)
			a[i] = 1ll * a[i] * b[i] % mod;
		NTT(a, fm, -1);

		for (int i = 0; i < fm; ++i)
			b[i] = 0;
		return;
	}

	// Schoolbook product accumulated into _b.
	for (int i = 0; i <= am; ++i)
	{
		const int lhs = a[i];
		for (int j = 0; j <= bm; ++j)
		{
			int &cell = _b[i + j];
			cell = (1ll * lhs * b[j] + cell) % mod;
		}
	}
	// Move the result into a and restore the accumulator to zero.
	for (int i = 0; i <= tot; ++i)
	{
		a[i] = _b[i];
		_b[i] = 0;
	}
	for (int i = 0; i <= bm; ++i)
		b[i] = 0;
}

/**
 * Replaces a[0..am] by the coefficients of its multiplicative inverse
 * modulo x^{am+1}, using Newton iteration B <- B * (2 - A * B) with the
 * precision doubling every round.
 *
 * @param a   input polynomial; a[0] must be invertible modulo `mod`.
 *            Entries a[am+1 .. len-1] (len = first power of two above am)
 *            are read too.
 * @param am  highest index that is written back.
 *
 * Uses the globals `inv_a` (iterate) and `c` (copy of the input) as scratch;
 * both are expected to be zero on entry and are zeroed again before
 * returning. `rev` and `tw` are overwritten.
 */
inline void poly_inv(int *a, int am)
{
	inv_a[0] = quick_pow(a[0], mod - 2);

	int len = 1;
	for (int bits = 1; len <= am; ++bits)
	{
		len <<= 1;
		const int fm = len << 1; // room for A * B * B without wrap-around

		for (int i = 1; i < fm; ++i)
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bits);
		for (int i = 0; i < len; ++i)
			c[i] = a[i];

		NTT(inv_a, fm, 1);
		NTT(c, fm, 1);
		for (int i = 0; i < fm; ++i)
		{
			// Point-wise 2*B - A*B*B.
			const int cur = inv_a[i];
			const int sub = 1ll * c[i] * cur % mod * cur % mod;
			int val = cur + cur;
			if (val >= mod)
				val -= mod;
			val -= sub;
			if (val < 0)
				val += mod;
			inv_a[i] = val;
		}
		NTT(inv_a, fm, -1);

		// Truncate to the current precision and wipe the scratch copy.
		for (int i = len; i < fm; ++i)
			inv_a[i] = 0;
		for (int i = 0; i < fm; ++i)
			c[i] = 0;
	}

	for (int i = 0; i <= am; ++i)
		a[i] = inv_a[i];
	for (int i = 0; i < len; ++i)
		inv_a[i] = 0;
}

/**
 * In-place formal derivative of a[0..am]: the coefficient at index i moves
 * to index i - 1 multiplied by i, and a[am] becomes 0.
 */
inline void poly_der(int *a, int am)
{
	for (int i = 1; i <= am; ++i)
		a[i - 1] = 1ll * a[i] * i % mod;
	a[am] = 0;
}

/**
 * In-place formal antiderivative of a[0..am] with constant term 0: the
 * coefficient at index i moves to index i + 1 multiplied by inv[i + 1].
 * Writes a[am + 1], and needs the global table inv[1..am+1].
 */
inline void poly_int(int *a, int am)
{
	// Walk downwards so every source coefficient is read before it is overwritten.
	for (int i = am; i >= 0; --i)
		a[i + 1] = 1ll * a[i] * inv[i + 1] % mod;
	a[0] = 0;
}

/**
 * Replaces a[0..am] by ln(a) modulo x^{am+1}, computed as the
 * antiderivative of a' / a.
 *
 * Because the result's constant term is forced to 0, a[0] is presumably
 * expected to be 1. The global `d` serves as scratch for the inverse of a.
 */
inline void poly_ln(int *a, int am)
{
	// d <- copy of a, then d <- 1 / a; a <- a'.
	for (int i = am; i >= 0; --i)
		d[i] = a[i];
	poly_der(a, am);
	poly_inv(d, am);

	// a <- a' * (1 / a); drop every coefficient from index am upwards
	// (the product can reach index 2 * am - 1).
	poly_mul(a, am - 1, d, am);
	for (int i = am; i < am + am; ++i)
		a[i] = 0;

	poly_int(a, am - 1);
}

// Replaces a[0..am] by exp(a) modulo x^{am+1} using Newton iteration
//   E <- E * (1 + A - ln E),
// doubling the working precision k every round.
// The iterate starts from exp_a[0] = 1, i.e. this presumes a[0] == 0.
// Scratch globals: exp_a (iterate), ln_b (correction factor), _b (schoolbook
// accumulator); exp_a and ln_b are zeroed again before returning.
// a[i] is read for every i < k, where k is the first power of two above am.
inline void poly_exp(int *a, int am)
{
	exp_a[0] = 1;
	int k = 1, cnt = 0;
	while (k <= am)
	{
		// k: precision of this round, fm = 2k: length needed for the product.
		k <<= 1, ++cnt;
		int fm = k << 1;
		// ln_b <- ln(E) mod x^k (poly_ln works in place on the copy).
		for (int i = 0; i < k; ++i)
			ln_b[i] = exp_a[i];
		poly_ln(ln_b, k - 1);
		// ln_b <- A - ln(E), then add 1 to the constant term.
		for (int i = 0; i < k; ++i)
		{
			int tmp = ln_b[i];
			ln_b[i] = a[i];
			dec(ln_b[i], tmp);
		}
		add(ln_b[0], 1);

		if (fm <= 512)
		{
			// Short lengths: schoolbook product E * ln_b through _b.
			for (int i = 0; i < fm; ++i)
				_b[i] = 0;
			for (int i = 0; i < k; ++i)
				for (int j = 0; j < k; ++j)
					_b[i + j] = (1ll * exp_a[i] * ln_b[j] + _b[i + j]) % mod;
			for (int i = 0; i < fm; ++i)
				exp_a[i] = _b[i], _b[i] = 0;
		}
		else
		{
			// rev has to be rebuilt here: the nested poly_ln call above
			// overwrote it for its own transform lengths.
			for (int i = 1; i < fm; ++i)
				rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << cnt);
			NTT(exp_a, fm, 1);
			NTT(ln_b, fm, 1);
			for (int i = 0; i < fm; ++i)
				exp_a[i] = 1ll * exp_a[i] * ln_b[i] % mod;
			NTT(exp_a, fm, -1);
		}
		// Truncate the iterate to x^k and wipe the scratch buffer.
		for (int i = k; i < fm; ++i)
			exp_a[i] = 0;
		for (int i = 0; i < fm; ++i)
			ln_b[i] = 0;
	}
	// Hand the result back and leave exp_a clean for a later call.
	for (int i = 0; i <= am; ++i)
		a[i] = exp_a[i];
	for (int i = 0; i < k; ++i)
		exp_a[i] = 0;
}

/**
 * Replaces a[0..am] by a^k modulo x^{am+1}, computed as exp(k * ln a).
 *
 * @param a   polynomial, handed to poly_ln first.
 * @param am  highest index kept.
 * @param k   exponent, already reduced modulo `mod` by the caller.
 */
inline void poly_pow(int *a, int am, int k)
{
	poly_ln(a, am);

	// Scale the logarithm by the exponent.
	int i = 0;
	while (i <= am)
	{
		a[i] = 1ll * a[i] * k % mod;
		++i;
	}

	poly_exp(a, am);
}

/** Returns x when x > y, otherwise y. */
template <class T>
inline T Max(T x, T y)
{
	if (x > y)
		return x;
	return y;
}
/** Returns x when x < y, otherwise y. */
template <class T>
inline T Min(T x, T y)
{
	if (x < y)
		return x;
	return y;
}

/**
 * Power of a short polynomial: fills b[0..bm] with the coefficients of
 * a(x)^K modulo x^{bm+1} in O(bm * am) time.
 *
 * With H = G^w, the identity G * H' = w * H * G' gives, for every m >= 0,
 *   (m + 1) * g_0 * h_{m+1}
 *       = w * sum_{j=0..m} (j + 1) * g_{j+1} * h_{m-j}
 *         -   sum_{j=0..m-1} (j + 1) * h_{j+1} * g_{m-j},
 * where only terms with g index <= am are non-zero.
 *
 * @param a   base polynomial G, indices 0..am. a[0] must be 1: the routine
 *            starts from b[0] = 1 and divides only by (m + 1), not by g_0.
 * @param am  highest index of `a`.
 * @param b   output array, indices 0..bm.
 * @param bm  highest index to compute.
 * @param K   exponent w as a residue modulo `mod` (e.g. inv2 for a square
 *            root, mod - inv2 for an inverse square root).
 *
 * Needs the global table inv[1..bm] of modular inverses.
 */
inline void poly_pow_s(int *a, int am, int *b, int bm, int K)
{
	b[0] = 1;
	for (int i = 0; i < bm; ++i)
	{
		// w * sum (j + 1) * g_{j+1} * h_{i-j}; g_{j+1} exists only for j + 1 <= am.
		int res = 0;
		for (int j = 0, jm = Min(i, am - 1); j <= jm; ++j)
			res = (1ll * a[j + 1] * (j + 1) % mod * b[i - j] + res) % mod;
		res = 1ll * K * res % mod;

		// Subtract sum (j + 1) * h_{j+1} * g_{i-j}; g_{i-j} is non-zero only
		// for i - j <= am, so the sum starts at j = i - am. The bound used to
		// be the literal i - 2, which was only right for am == 2.
		for (int j = Max(i - am, 0); j < i; ++j)
			dec(res, 1ll * b[j + 1] * (j + 1) % mod * a[i - j] % mod);

		b[i + 1] = 1ll * res * inv[i + 1] % mod;
	}
}

/**
 * Reads n and K, evaluates
 *   F_n(x) = (1 + 6x + x^2)^{-1/2} * ((1 + x + (1 + 6x + x^2)^{1/2}) / 2)^{n+1}
 * modulo x^{K+1}, and prints the coefficients of x^1 .. x^K separated by
 * spaces (0 for every index greater than n).
 */
int main()
{
	read(n);
	read(K);

	// Modular inverses of 1..K via inv[i] = -(mod / i) * inv[mod % i].
	inv[1] = 1;
	for (int i = 2; i <= K; ++i)
	{
		const int quotient = mod / i;
		const int remainder = mod % i;
		inv[i] = 1ll * (mod - quotient) * inv[remainder] % mod;
	}

	// G(x) = 1 + 6x + x^2;  a <- G^{1/2},  b <- G^{-1/2}.
	_a[0] = 1;
	_a[1] = 6;
	_a[2] = 1;
	poly_pow_s(_a, 2, a, K, inv2);
	poly_pow_s(_a, 2, b, K, mod - inv2);

	// a <- (1 + x + G^{1/2}) / 2, then a <- a^{n+1}.
	add(a[0], 1);
	add(a[1], 1);
	for (int i = 0; i <= K; ++i)
		a[i] = 1ll * a[i] * inv2 % mod;
	poly_pow(a, K, (n + 1) % mod);

	// a <- a * G^{-1/2}.
	poly_mul(a, K, b, K);

	for (int i = 1; i <= K; ++i)
	{
		if (i <= n)
			put(a[i]);
		else
			put(0);
		putchar(' ');
	}

	fclose(stdin);
	fclose(stdout);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章