基本
通俗地說, 係數表達 → 點值表達, 稱爲 DFT, 點值表達 → 係數表達, 稱爲 IDFT。
FFT 通過取某些特殊的 x 的點值來加速 DFT 和 IDFT。
考慮點值表示下的多項式乘法:
明顯是 O(n) 的。
如此,通過 FFT, 可以實現快速的多項式乘法。
分治結構
設個 \(g(x) = a_0 + a_2x + a_4x^2 + a_6x^3\), 再設個 \(h(x) = a_1 + a_3x + a_5x^2 + a_7x^3\), 就有:
接下來是精彩的地方, 前面說的特殊點值要發揮作用了。
帶入 n 次的某個單位根, 首先有:
然後有:
這個分治的結構就清晰可見了, 雖然本質什麼的還不是很清楚, 但可以窺見一絲構造的痕跡。
IDFT
帶入單位根的共軛複數 DFT 一下, 再把得到的東西除以 n 就行了。
代碼
抄的學長的實現, 是有點優化的寫法。目前沒考慮到封裝。
不用算 rev 的 FFT 真的那麼 dio 嗎?
#include<bits/stdc++.h>
using namespace std;
int rd() {
int x = 0;
char c = getchar();
while(c<'0' || c>'9') c=getchar();
while(c>='0' && c<='9') x=x*10+c-'0', c=getchar();
return x;
}
const int N = (1<<21)+ 233;
const double pi = acos(-1);
struct com {
double x, y;
com(double a, double b) : x(a), y(b) {
}
com() {
x=y=0;
}
const com operator+(const com rhs) const{
return com(x+rhs.x, y+rhs.y);
}
const com operator-(const com rhs) const{
return com(x-rhs.x, y-rhs.y);
}
const com operator*(const com rhs) const{
return com(x*rhs.x - y*rhs.y, x*rhs.y + y*rhs.x);
}
};
int n, m, rv[N];
com a[N], b[N];
void fft(com *a, int n, int type) {
for(int i=0; i<n; ++i) if(i<=rv[i]) swap(a[i], a[rv[i]]);
for(int m=2; m<=n; m<<=1) {
com w(cos(2 * pi / m), type * sin(2 * pi / m));
for(int i=0; i<n; i += m) {
com tmp = com(1, 0);
for(int j=0; j<(m>>1); ++j) {
com p = a[i+j], q = tmp * a[i+j+(m>>1)];
a[i+j] = p + q,
a[i+j+(m>>1)] = p - q;
tmp = tmp * w;
}
}
}
}
int main() {
n = rd()+1, m = rd() + 1;
for(int i=0; i<n; ++i) a[i].x = rd();
for(int i=0; i<m; ++i) b[i].x = rd();
for(m=n+m-1, n=1; n<m; n=n<<1);
for(int i=0; i<n; ++i) rv[i] = (rv[i>>1]>>1)|(i&1?(n>>1):0);
fft(a, n, 1);
fft(b, n, 1);
for(int i=0; i<n; ++i) a[i] = a[i] * b[i];
fft(a, n, -1);
for(int i=0; i<m; ++i) printf("%d ", (int)(a[i].x/n+0.5));
return 0;
}
現在正要加速學習多項式, 所以 NTT 就先背個版吧。
#include<bits/stdc++.h>
typedef long long LL;
using namespace std;
const int N = 3e6 + 23, mo = 998244353, g = 3;
int read() {
char c = getchar(); int x = 0;
while(c < '0' || c > '9') c = getchar();
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x;
}
int ksm(int a, int b) {
int res = 1;
for(; b; b=b>>1, a=((LL)a*a) % mo)
if(b & 1) res = (LL)res * a % mo;
return res;
}
const int ig = ksm(g, mo-2);
int n, m, a[N], b[N], rv[N];
void ntt(int *a, int n, int type) {
for(int i=0; i<n; ++i) if(i<rv[i]) swap(a[i], a[rv[i]]);
for(int m=2; m<=n; m<<=1) {
int w = ksm(type == 1 ? g : ig, (mo-1)/m);
for(int i = 0; i < n; i += m) {
int tmp = 1;
for(int j = 0; j < (m>>1); ++j) {
int p = a[i+j], q = (LL)tmp * a[i+j+(m>>1)] % mo;
a[i + j] = (p + q) % mo, a[i + j + (m>>1)] = (p - q + mo) % mo;
tmp = (LL)tmp * w % mo;
}
}
}
}
int main() {
n = read()+1, m = read()+1;
for(int i=0;i<n;++i) a[i]=read();
for(int i=0;i<m;++i) b[i]=read();
for(m=n+m-1,n=1;n<m;n<<=1);
for(int i=0;i<n;++i) rv[i] = (rv[i>>1]>>1)|((i&1)?(n>>1):0);
ntt(a, n, 1), ntt(b, n, 1);
for(int i = 0; i < n; ++i) a[i] = (LL)a[i] * b[i] % mo;
ntt(a, n, -1);
int inv = ksm(n, mo-2);
for(int i=0; i<m; ++i) cout << (LL)a[i] * inv % mo << ' ';
return 0;
}