FFT/NTT

完全抄襲自 OI-wiki


基本

通俗地說, 係數表達 → 點值表達, 稱爲 DFT, 點值表達 → 係數表達, 稱爲 IDFT。

FFT 通過取某些特殊的 x 的點值來加速 DFT 和 IDFT。

考慮點值表示下的多項式乘法:

\[f(x) = (x_0,f(x_0)),(x_1,f(x_1)),\cdots,(x_n,f(x_n)) \\ g(x) = (x_0,g(x_0)),(x_1,g(x_1)),\cdots,(x_n,g(x_n)) \\ (f\cdot g)(x) = f(x)\cdot g(x) \\ (f\cdot g)(x) = (x_0,f(x_0)g(x_0)), (x_1,f(x_1)g(x_1)), \cdots,(x_n,f(x_n)g(x_n)) \]

明顯是 O(n) 的。

如此,通過 FFT, 可以實現快速的多項式乘法。


分治結構

\[\begin{align} f(x) &= a_0 + a_1x+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7 \\ &= (a_0+a_2x^2+a_4x^4+a_6x^6) + x\cdot(a_1+a_3x^2+a_5x^4+a_7x^6) \end{align} \]

設個 \(g(x) = a_0 + a_2x + a_4x^2 + a_6x^3\), 再設個 \(h(x) = a_1 + a_3x + a_5x^2 + a_7x^3\), 就有:

\[f(x) = g(x^2) + x\cdot h(x^2) \]

接下來是精彩的地方, 前面說的特殊點值要發揮作用了。

帶入 n 次的某個單位根, 首先有:

\[\begin{align} f(\omega_n^k) &= g(\omega_n^{2k}) + \omega_n^k \cdot h(\omega_n^{2k}) \\ &= g(\omega_{n/2}^k) + \omega_n^k\cdot h(\omega_{n/2}^k) \end{align} \]

然後有:

\[\begin{align} f(\omega_n^{k + n/2}) &= g(\omega_{n}^{2k+n}) + \omega_{n}^{k+n/2}\cdot h(\omega_n^{2k+n}) \\ &= g(\omega_{n/2}^k) - \omega_n^k\cdot h(\omega_{n/2}^k) \end{align} \]

這個分治的結構就清晰可見了, 雖然本質什麼的還不是很清楚, 但可以窺見一絲構造的痕跡。


IDFT

帶入單位根的共軛複數 DFT 一下, 再把得到的東西除以 n 就行了。


代碼

抄的學長的實現, 是有點優化的寫法。目前沒考慮到封裝。

不用算 rev 的 FFT 真的那麼 dio 嗎?

#include<bits/stdc++.h>

using namespace std;

int rd() {
	int x = 0;
	char c = getchar();
	while(c<'0' || c>'9') c=getchar();
	while(c>='0' && c<='9') x=x*10+c-'0', c=getchar();
	return x;
}

const int N = (1<<21)+ 233;
const double pi = acos(-1);

struct com {
	double x, y;
	com(double a, double b) : x(a), y(b) {
	}
	com() {
		x=y=0;
	}
	const com operator+(const com rhs) const{
		return com(x+rhs.x, y+rhs.y);
	}
	const com operator-(const com rhs) const{
		return com(x-rhs.x, y-rhs.y);
	}
	const com operator*(const com rhs) const{
		return com(x*rhs.x - y*rhs.y, x*rhs.y + y*rhs.x);
	}
};

int n, m, rv[N];
com a[N], b[N];

void fft(com *a, int n, int type) {
	for(int i=0; i<n; ++i) if(i<=rv[i]) swap(a[i], a[rv[i]]);
	for(int m=2; m<=n; m<<=1) {
		com w(cos(2 * pi / m), type * sin(2 * pi / m));
		for(int i=0; i<n; i += m) {
			com tmp = com(1, 0);
			for(int j=0; j<(m>>1); ++j) {
				com p = a[i+j], q = tmp * a[i+j+(m>>1)];
				a[i+j] = p + q,
				a[i+j+(m>>1)] = p - q;
				tmp = tmp * w;
			}
		}
	}
}

int main() {
	n = rd()+1, m = rd() + 1;
	for(int i=0; i<n; ++i) a[i].x = rd();
	for(int i=0; i<m; ++i) b[i].x = rd();
	for(m=n+m-1, n=1; n<m; n=n<<1);
	for(int i=0; i<n; ++i) rv[i] = (rv[i>>1]>>1)|(i&1?(n>>1):0);
	fft(a, n, 1);
	fft(b, n, 1);
	for(int i=0; i<n; ++i) a[i] = a[i] * b[i];
	fft(a, n, -1);
	for(int i=0; i<m; ++i) printf("%d ", (int)(a[i].x/n+0.5));
	return 0;
}

現在正要加速學習多項式, 所以 NTT 就先背個版吧。

#include<bits/stdc++.h>
typedef long long LL;
using namespace std;

const int N = 3e6 + 23, mo = 998244353, g = 3;

int read() { 
    char c = getchar(); int x = 0;
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x;
}

int ksm(int a, int b) {
	int res = 1;
	for(; b; b=b>>1, a=((LL)a*a) % mo)
		if(b & 1) res = (LL)res * a % mo;
	return res;
}

const int ig = ksm(g, mo-2);

int n, m, a[N], b[N], rv[N];

void ntt(int *a, int n, int type) {
	for(int i=0; i<n; ++i) if(i<rv[i]) swap(a[i], a[rv[i]]);
	for(int m=2; m<=n; m<<=1) {
		int w = ksm(type == 1 ? g : ig, (mo-1)/m);
		for(int i = 0; i < n; i += m) {
			int tmp = 1;
			for(int j = 0; j < (m>>1); ++j) {
				int p = a[i+j], q = (LL)tmp * a[i+j+(m>>1)] % mo;
				a[i + j] = (p + q) % mo, a[i + j + (m>>1)] = (p - q + mo) % mo;
				tmp = (LL)tmp * w % mo;
			}
		}
	}
}

int main() {
	n = read()+1, m = read()+1;
	for(int i=0;i<n;++i) a[i]=read();
	for(int i=0;i<m;++i) b[i]=read();
	for(m=n+m-1,n=1;n<m;n<<=1);
	for(int i=0;i<n;++i) rv[i] = (rv[i>>1]>>1)|((i&1)?(n>>1):0);
	ntt(a, n, 1), ntt(b, n, 1);
	for(int i = 0; i < n; ++i) a[i] = (LL)a[i] * b[i] % mo;
	ntt(a, n, -1);
	int inv = ksm(n, mo-2);
	for(int i=0; i<m; ++i) cout << (LL)a[i] * inv % mo << ' ';
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章