【Codechef DEVLOCK Devu and Locks】【倍增二維FFT】

題意

求有多少個 $n$ 位十進制數(可以有前導零),滿足模 $p$ 等於 $0$ 且各位數字之和不超過 $m$。
$n\le 10^9,\ p\le 16,\ m\le 15000$

分析

注意到第 $i$ 位貢獻的係數爲 $10^i\bmod p$,兩位的貢獻不同當且僅當對應係數不同。因此可以把數位按照係數分類。設 $num_i$ 表示有多少位滿足貢獻係數爲 $i$。$num_i$ 可以通過求 $10^k \bmod p$ 的週期來求出。

可以用倍增FFT求出 $f_{i,j}$,表示選了 $num_i$ 個 $0$ 到 $9$ 之間的數、和爲 $j$ 的方案數。顯然 $f_{i,j}$ 的每種方案中,選取數字乘上對應係數之和模 $p$ 的值爲 $ij\bmod p$。從而得到 $g_{i,k,j}$,表示在係數爲 $i$ 的位置中,選出來的數乘上係數之和模 $p$ 的值爲 $k$,且選出來的數之和爲 $j$ 的方案數。把 $g_0,\cdots,g_{p-1}$ 通過二維FFT合併,就得到答案了。

二維FFT的實現方法是先對 $A$ 的第一維做DFT得到數組 $B$,再對 $B$ 的第二維做DFT得到數組 $C$。則 $C$ 就是二維DFT得到的數組。但這題裏面 $p$ 比較小,因此第二維可以直接暴力卷積。

倍增FFT的時間複雜度爲 $O(pm\log n\log m)$,二維FFT的時間複雜度爲 $O(p^2m\log m+p^3m)$,因此總的時間複雜度爲 $O(pm\log n\log m+p^2m\log m+p^3m)$。

代碼

#include<bits/stdc++.h>
#define pb push_back
using namespace std;

// 64-bit integer used for intermediate products before reduction modulo MOD.
using LL = long long;

// Capacity of every polynomial buffer (pre() picks the transform length L at run time).
const int N = 33005;
// Capacity of the per-residue tables (indexed by a residue modulo p).
const int P = 55;
// Prime modulus used by the number-theoretic transform; 3 is used as its primitive root.
const int MOD = 998244353;

// Inputs read by main(): number of digits, modulus, and digit-sum bound.
int n;
int p;
int m;

// num[r]: how many digit positions have weight 10^k mod p equal to r (filled by pre()).
int num[P];
// bz[i]: transformed polynomial for 2^i digits, built by repeated squaring in solve1().
int bz[40][N];
// f[r][s]: number of ways for the num[r] positions of weight r to have digit sum s.
int f[P][N];
// g, tmp: scratch buffers for the residue-wise convolution in solve2().
int g[P][N];
int tmp[P][N];
// ans[r][s]: merged count of assignments with weighted sum r (mod p) and digit sum s.
int ans[P][N];
// L: transform length chosen by pre(); rev[]: bit-reversal permutation for that length.
int L;
int rev[N];
// wn1[k] / wn2[k]: forward / inverse twiddle factors for butterflies of half-width 2^k.
std::vector<int> wn1[25];
std::vector<int> wn2[25];

// Greatest common divisor of x and y, computed with the iterative Euclidean
// algorithm. Returns x unchanged when y is 0.
int gcd(int x, int y)
{
	while (y != 0)
	{
		const int remainder = x % y;
		x = y;
		y = remainder;
	}
	return x;
}

// Modular exponentiation ("fast power"): returns x^y mod mo using binary
// square-and-multiply.
//   x  : base
//   y  : exponent; values <= 0 are treated as 0 (the result is then 1 % mo)
//   mo : modulus, must be >= 1
// The accumulator starts at 1 % mo so that mo == 1 correctly yields 0 even for
// y == 0, and the loop condition is y > 0 so a negative exponent cannot spin
// forever (an arithmetic right shift of a negative value never reaches 0).
int ksm(int x, int y, int mo)
{
	long long base = x % mo;
	long long result = 1 % mo;
	for (; y > 0; y >>= 1)
	{
		if (y & 1) result = result * base % mo;
		base = base * base % mo;
	}
	return static_cast<int>(result);
}

void pre()
{
	int now[p], w = 1 % p, ls;
	memset(now, 0, sizeof(now));
	for (int i = 1; i <= n && !now[w]; i++, w = w * 10 % p) num[w]++, now[w] = i, ls = i;
	int T = ls - now[w] + 1, tmp = n - now[w] + 1 - T;
	for (int i = 0; i < T; i++, w = w * 10 % p) num[w] += tmp / T;
	for (int i = 0; i < tmp % T; i++, w = w * 10 % p) num[w]++;
	int lg = 0;
	for (L = 1; L <= m * 2; L <<= 1, lg++);
	for (int i = 0; i < L; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
	for (int i = 0; i < 20; i++)
	{
		int w1 = ksm(3, (MOD - 1) / (1 << i) / 2, MOD), w2 = ksm(3, MOD - 1 - (MOD - 1) / (1 << i) / 2, MOD);
		wn1[i].pb(1); wn2[i].pb(1);
		for (int j = 1; j < (1 << i); j++) wn1[i].pb((LL)wn1[i][j - 1] * w1 % MOD), wn2[i].pb((LL)wn2[i][j - 1] * w2 % MOD);
	}
}

void NTT(int * a, int f)
{
	for (int i = 0; i < L; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int i = 1, lg = 0; i < L; i <<= 1, lg++)
		for (int j = 0; j < L; j += (i << 1))
			for (int k = 0; k < i; k++)
			{
				int u = a[j + k], v = (LL)a[j + k + i] * (f == 1 ? wn1[lg][k] : wn2[lg][k]) % MOD;
				a[j + k] = (u + v) % MOD; a[j + k + i] = (u + MOD - v) % MOD;
			}
	if (f == -1) for (int i = 0, inv = ksm(L, MOD - 2, MOD); i < L; i++) a[i] = (LL)a[i] * inv % MOD;
}

// Builds f[r] for every residue r < p: the polynomial whose coefficient of
// x^s counts the ways num[r] digits (each 0..9) can have digit sum s, with
// s truncated to [0, m]. Uses binary exponentiation of the single-digit
// polynomial: bz[bit] holds the transformed polynomial for 2^bit digits.
void solve1()
{
	int maxCount = 0;
	for (int r = 0; r < p; ++r)
	{
		if (num[r] > maxCount) maxCount = num[r];
		f[r][0] = 1; // empty product: the constant polynomial 1
	}

	// Single-digit polynomial 1 + x + ... + x^9, cut off at x^m.
	const int topDigit = (m < 9) ? m : 9;
	for (int d = 0; d <= topDigit; ++d) bz[0][d] = 1;
	NTT(bz[0], 1);

	for (int bit = 0; (1 << bit) <= maxCount; ++bit)
	{
		int * const power = bz[bit];

		// Multiply the current power into every residue class whose count has this bit set.
		for (int r = 0; r < p; ++r)
		{
			if ((num[r] & (1 << bit)) == 0) continue;
			int * const poly = f[r];
			NTT(poly, 1);
			for (int k = 0; k < L; ++k) poly[k] = (LL)poly[k] * power[k] % MOD;
			NTT(poly, -1);
			for (int k = m + 1; k < L; ++k) poly[k] = 0; // drop sums above m
		}

		// No higher bit can be set in any count, so the next square is not needed.
		if ((1 << (bit + 1)) > maxCount) break;

		// Square in the transformed domain, truncate to degree m, transform again.
		int * const next = bz[bit + 1];
		for (int k = 0; k < L; ++k) next[k] = (LL)power[k] * power[k] % MOD;
		NTT(next, -1);
		for (int k = m + 1; k < L; ++k) next[k] = 0;
		NTT(next, 1);
	}
}

void solve2()
{
	ans[0][0] = 1;
	for (int i = 0; i < p; i++)
	{
		for (int j = 0; j < p; j++) memset(g[j], 0, sizeof(g[j])), memset(tmp[j], 0, sizeof(tmp[j]));
		for (int j = 0; j <= m; j++)
			(g[j * i % p][j] += f[i][j]) %= MOD;
		for (int j = 0; j < p; j++) NTT(ans[j], 1), NTT(g[j], 1);
		for (int j = 0; j < p; j++)
			for (int k = 0, tar = j; k < p; k++, tar = (tar + 1) % p)
				for (int l = 0; l < L; l++)
					(tmp[tar][l] += (LL)ans[j][l] * g[k][l] % MOD) %= MOD;
		for (int j = 0; j < p; j++)
		{
			NTT(tmp[j], -1);
			for (int k = 0; k <= m; k++) ans[j][k] = tmp[j][k];
			for (int k = m + 1; k < L; k++) ans[j][k] = 0;
		}
	}
}

int main()
{
	scanf("%d%d%d", &n, &p, &m);
	pre();
	solve1();
	solve2();
	for (int i = 1; i <= m; i++) (ans[0][i] += ans[0][i - 1]) %= MOD;
	for (int i = 0; i <= m; i++) printf("%d ", ans[0][i]);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章