FFT/NTT

完全抄袭自 OI-wiki


基本

通俗地说, 系数表达 → 点值表达, 称为 DFT, 点值表达 → 系数表达, 称为 IDFT。

FFT 通过取某些特殊的 x 的点值来加速 DFT 和 IDFT。

考虑点值表示下的多项式乘法:

\[f(x) = (x_0,f(x_0)),(x_1,f(x_1)),\cdots,(x_n,f(x_n)) \\ g(x) = (x_0,g(x_0)),(x_1,g(x_1)),\cdots,(x_n,g(x_n)) \\ (f\cdot g)(x) = f(x)\cdot g(x) \\ (f\cdot g)(x) = (x_0,f(x_0)g(x_0)), (x_1,f(x_1)g(x_1)), \cdots,(x_n,f(x_n)g(x_n)) \]

明显是 O(n) 的。

如此,通过 FFT, 可以实现快速的多项式乘法。


分治结构

\[\begin{align} f(x) &= a_0 + a_1x+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7 \\ &= (a_0+a_2x^2+a_4x^4+a_6x^6) + x\cdot(a_1+a_3x^2+a_5x^4+a_7x^6) \end{align} \]

设个 \(g(x) = a_0 + a_2x + a_4x^2 + a_6x^3\), 再设个 \(h(x) = a_1 + a_3x + a_5x^2 + a_7x^3\), 就有:

\[f(x) = g(x^2) + x\cdot h(x^2) \]

接下来是精彩的地方, 前面说的特殊点值要发挥作用了。

带入 n 次的某个单位根, 首先有:

\[\begin{align} f(\omega_n^k) &= g(\omega_n^{2k}) + \omega_n^k \cdot h(\omega_n^{2k}) \\ &= g(\omega_{n/2}^k) + \omega_n^k\cdot h(\omega_{n/2}^k) \end{align} \]

然后有:

\[\begin{align} f(\omega_n^{k + n/2}) &= g(\omega_{n}^{2k+n}) + \omega_{n}^{k+n/2}\cdot h(\omega_n^{2k+n}) \\ &= g(\omega_{n/2}^k) - \omega_n^k\cdot h(\omega_{n/2}^k) \end{align} \]

这个分治的结构就清晰可见了, 虽然本质什么的还不是很清楚, 但可以窥见一丝构造的痕迹。


IDFT

带入单位根的共轭复数 DFT 一下, 再把得到的东西除以 n 就行了。


代码

抄的学长的实现, 是有点优化的写法。目前没考虑到封装。

不用算 rev 的 FFT 真的那么 dio 吗?

#include<bits/stdc++.h>

using namespace std;

int rd() {
	int x = 0;
	char c = getchar();
	while(c<'0' || c>'9') c=getchar();
	while(c>='0' && c<='9') x=x*10+c-'0', c=getchar();
	return x;
}

const int N = (1<<21)+ 233;
const double pi = acos(-1);

struct com {
	double x, y;
	com(double a, double b) : x(a), y(b) {
	}
	com() {
		x=y=0;
	}
	const com operator+(const com rhs) const{
		return com(x+rhs.x, y+rhs.y);
	}
	const com operator-(const com rhs) const{
		return com(x-rhs.x, y-rhs.y);
	}
	const com operator*(const com rhs) const{
		return com(x*rhs.x - y*rhs.y, x*rhs.y + y*rhs.x);
	}
};

int n, m, rv[N];
com a[N], b[N];

void fft(com *a, int n, int type) {
	for(int i=0; i<n; ++i) if(i<=rv[i]) swap(a[i], a[rv[i]]);
	for(int m=2; m<=n; m<<=1) {
		com w(cos(2 * pi / m), type * sin(2 * pi / m));
		for(int i=0; i<n; i += m) {
			com tmp = com(1, 0);
			for(int j=0; j<(m>>1); ++j) {
				com p = a[i+j], q = tmp * a[i+j+(m>>1)];
				a[i+j] = p + q,
				a[i+j+(m>>1)] = p - q;
				tmp = tmp * w;
			}
		}
	}
}

int main() {
	n = rd()+1, m = rd() + 1;
	for(int i=0; i<n; ++i) a[i].x = rd();
	for(int i=0; i<m; ++i) b[i].x = rd();
	for(m=n+m-1, n=1; n<m; n=n<<1);
	for(int i=0; i<n; ++i) rv[i] = (rv[i>>1]>>1)|(i&1?(n>>1):0);
	fft(a, n, 1);
	fft(b, n, 1);
	for(int i=0; i<n; ++i) a[i] = a[i] * b[i];
	fft(a, n, -1);
	for(int i=0; i<m; ++i) printf("%d ", (int)(a[i].x/n+0.5));
	return 0;
}

现在正要加速学习多项式, 所以 NTT 就先背个版吧。

#include<bits/stdc++.h>
typedef long long LL;
using namespace std;

const int N = 3e6 + 23, mo = 998244353, g = 3;

int read() { 
    char c = getchar(); int x = 0;
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x;
}

int ksm(int a, int b) {
	int res = 1;
	for(; b; b=b>>1, a=((LL)a*a) % mo)
		if(b & 1) res = (LL)res * a % mo;
	return res;
}

const int ig = ksm(g, mo-2);

int n, m, a[N], b[N], rv[N];

void ntt(int *a, int n, int type) {
	for(int i=0; i<n; ++i) if(i<rv[i]) swap(a[i], a[rv[i]]);
	for(int m=2; m<=n; m<<=1) {
		int w = ksm(type == 1 ? g : ig, (mo-1)/m);
		for(int i = 0; i < n; i += m) {
			int tmp = 1;
			for(int j = 0; j < (m>>1); ++j) {
				int p = a[i+j], q = (LL)tmp * a[i+j+(m>>1)] % mo;
				a[i + j] = (p + q) % mo, a[i + j + (m>>1)] = (p - q + mo) % mo;
				tmp = (LL)tmp * w % mo;
			}
		}
	}
}

int main() {
	n = read()+1, m = read()+1;
	for(int i=0;i<n;++i) a[i]=read();
	for(int i=0;i<m;++i) b[i]=read();
	for(m=n+m-1,n=1;n<m;n<<=1);
	for(int i=0;i<n;++i) rv[i] = (rv[i>>1]>>1)|((i&1)?(n>>1):0);
	ntt(a, n, 1), ntt(b, n, 1);
	for(int i = 0; i < n; ++i) a[i] = (LL)a[i] * b[i] % mo;
	ntt(a, n, -1);
	int inv = ksm(n, mo-2);
	for(int i=0; i<m; ++i) cout << (LL)a[i] * inv % mo << ' ';
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章