快速傅里葉變換(FFT),就是在O(nlogn)的時間內求出多項式係數矩陣a = (a1, a2, a3...an)的離散傅里葉變換(DFT)矩陣y = (y1, y2...yn)
如果你對傅里葉變換不是很瞭解,不要去網上查博客,打開算法導論的第30章,認認真真讀一遍就懂了!!!!!!!!!
這是以n=4爲例手推的FFT過程
#include<iostream>
#include<sstream>
#include<fstream>
#include<vector>
#include<list>
#include<deque>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cmath>
#include<ctime>
#include<iomanip>
using namespace std;
const double eps(1e-8);
typedef long long lint;
const double PI = acos(-1.0);
struct Complex
{
double real, image;
Complex(double _real, double _image)
{
real = _real;
image = _image;
}
Complex(){}
};
Complex operator + (const Complex &c1, const Complex &c2)
{
return Complex(c1.real + c2.real, c1.image + c2.image);
}
Complex operator - (const Complex &c1, const Complex &c2)
{
return Complex(c1.real - c2.real, c1.image - c2.image);
}
Complex operator * (const Complex &c1, const Complex &c2)
{
return Complex(c1.real*c2.real - c1.image*c2.image, c1.real*c2.image + c1.image*c2.real);
}
int rev(int id, int len)
{
int ret = 0;
for(int i = 0; (1 << i) < len; i++)
{
ret <<= 1;
if(id & (1 << i)) ret |= 1;
}
return ret;
}
Complex A[140000];
void FFT(Complex* a, int len, int DFT)//對a進行DFT或者逆DFT, 結果存在a當中 //對應算法導論p537
{
//Complex* A = new Complex[len]; 這麼寫會爆棧
for(int i = 0; i < len; i++)
A[rev(i, len)] = a[i]; //下面說的二叉樹是算法導論p537最上邊的那顆二叉樹
for(int s = 1; (1 << s) <= len; s++) //樹有幾層
{
int m = (1 << s);
Complex wm = Complex(cos(DFT*2*PI/m), sin(DFT*2*PI/m));
for(int k = 0; k < len; k += m) //每層幾個結點
{
Complex w = Complex(1, 0);
for(int j = 0; j < (m >> 1); j++) //每個結點進行幾次蝴蝶操作
{ //蝴蝶操作
Complex t = w*A[k + j + (m >> 1)];
Complex u = A[k + j];
A[k + j] = u + t;
A[k + j + (m >> 1)] = u - t;
w = w*wm;
}
}
}
if(DFT == -1) for(int i = 0; i < len; i++) A[i].real /= len, A[i].image /= len;
for(int i = 0; i < len; i++) a[i] = A[i];
return;
}
char numA[50010], numB[50010];//以每一位爲係數, 那麼多項式長度不超過50000
Complex a[140000], b[140000];//對應的乘積的長度不會超過100000, 也就是不超過(1 << 17) = 131072
int ans[140000];
int main()
{
while(~scanf("%s", numA))
{
int lenA = strlen(numA);
int sa = 0;
while((1 << sa) < lenA) sa++;
scanf("%s", numB);
int lenB = strlen(numB);
int sb = 0;
while((1 << sb) < lenB) sb++;
//那麼乘積多項式的次數不會超過(1 << (max(sa, sb) + 1))
int len = (1 << (max(sa, sb) + 1));
for(int i = 0; i < len; i++)
{
if(i < lenA) a[i] = Complex(numA[lenA - i - 1] - '0', 0);
else a[i] = Complex(0, 0);
if(i < lenB) b[i] = Complex(numB[lenB - i - 1] - '0', 0);
else b[i] = Complex(0, 0);
}
FFT(a, len, 1);
FFT(b, len, 1);//把A和B換成點值表達
for(int i = 0; i < len; i++)//做點值表達的成乘法
a[i] = a[i]*b[i];
FFT(a, len, -1);//逆DFT換回原來的係數, 虛部一定是0
for(int i = 0; i < len; i++)
ans[i] = (int)(a[i].real + 0.5);//取整誤差的處理
for(int i = 0; i < len - 1; i++)//進位問題
{
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
}
bool flag = 0;
for(int i = len - 1; i >= 0; i--)//注意輸出格式的調整即可
{
if(ans[i]) printf("%d", ans[i]), flag = 1;
else if(flag || i == 0) printf("0");
}
putchar('\n');
}
return 0;
}