快速傅里葉變換(FFT)

  首先說一下我用FFT做什麼,我要做的是多項式乘法,或者說,加速多項式乘法。
  考慮多項式A(x)=j=0n1ajxj ,它一共有n 項,我們稱它的次數界爲n 。假設我們有兩個次數界爲n 的多項式A(x)B(x) ,要求它們的和是非常簡單的,只需要將對應的係數相加,複雜度爲O(n) 。如果要求他們的積,則需要將A(x) 的每一項和B(x) 的每一項相乘,複雜度爲O(n2) ,這就顯得有點慢了。
  上面提到的多項式表示方法A(x)=j=0n1ajxj 稱爲係數表示,實際上它還有另一種表示方法叫點值表示。我們取n 個不同的值x0x1 ,…,xn1 代入多項式,可以得到n 個點(x0,A(x0)) ,…,(xn1,A(xn1)) 。這n 個點可以確定一個唯一的原多項式(至於爲什麼,就不細說了)。
  假設我們可以迅速在多項式的係數表示和點值表示間轉換,就可以迅速完成多項式乘法。回到那個多項式A(x) ,我們把x 的取值範圍擴展到複數,ωn 表示n (假設爲2的整數冪)次單位複數根,恰當地選擇多項式中x 的取值,使其分別等於ω1n,ω2n,...,ωnn ,就可以利用一些性質,遞歸地計算,進行加速。

下面放出一道題:Thief in a Shop
  可以這樣理解題意:有n 個數a1 ~an ,從中選取k 個(可重複),問得到的和可以是多少。我們可以把它轉化爲一個多項式,若存在ai ,則多項式第ai 項的係數爲1 ,否則爲0 。這樣一來,兩個多項式的乘積中,係數不爲0 的項就是k=2 時的解,繼續乘下去,可以得到k=3,4,5... 的解。
  於是,我們就需要FFT來迅速完成多項式乘法,同時利用倍增,只進行log2(k) 次乘法。注意每次乘完之後,要對那個列向量規整一下,避免迭代過程累積誤差。
  下面的代碼可以作爲模板。。

#include <bits/stdc++.h>

using namespace std;

#define ll long long

const double eps = 0.5;
const double PI = acos(-1.0);

struct Complex{
    double r,i;
    Complex(double r=0.0,double i=0.0):r(r),i(i){
    }

    Complex operator+(const Complex& c)const{
        return Complex(r+c.r,i+c.i);
    }

    Complex operator-(const Complex& c)const{
        return Complex(r-c.r,i-c.i);
    }

    Complex operator*(const Complex& c)const{
        return Complex(r*c.r-i*c.i,r*c.i+i*c.r);
    }
};

void change(Complex y[],int len){
    for(int i=1,j=len>>1;i<len-1;i++){
        if(i<j){
            swap(y[i],y[j]);
        }
        int k = len>>1;
        while(j>=k){
            j -= k;
            k >>= 1;
        }
        if(j<k){
            j += k;
        }
    }
}

void fft(Complex y[],int len,int on){
    change(y,len);
    for(int i=2;i<=len;i<<=1){
        Complex wn(cos(-on*2*PI/i),sin(-on*2*PI/i));
        for(int j=0;j<len;j+=i){
            Complex w(1,0);
            for(int k=j;k<j+i/2;k++){
                Complex u = y[k];
                Complex t = w*y[k+i/2];
                y[k] = u + t;
                y[k+i/2] = u - t;
                w = w * wn;
            }
        }
    }
    if(on == -1){
        for(int i=0;i<len;i++){
            y[i].r /= len;
        }
    }
}

int lowbit(int x){
    return x&(-x);
}

int fix(Complex *y,int l){
    while(l && y[l-1].r<eps ){
        l--;
    }
    for(int i=0;i<l;i++){
        Complex &c = y[i];
        if(c.r>eps){
            c.r = 1;
        }else{
            c.r = 0;
        }
        c.i = 0;
    }
    for(int i=l;i<1024*1024;i++){
        Complex &c = y[i];
        c.r = c.i = 0;
    } 
    return l+1;
}

void Print(Complex *y,int l){
    for(int i=0;i<l;i++){
        if(y[i].r>eps){
            cout<<i<<" ";
        }
    }
    cout<<endl;
}

int mul(Complex *v1,int l1,Complex *v2,int l2,Complex *res){
    l1 = fix(v1,l1);
    l2 = fix(v2,l2);

    int sz = 2*max(l1,l2);

    while(sz!=lowbit(sz)){
        sz+=lowbit(sz);
    }

    l1 = l2 = sz;

    fft(v1,l1,1);
    fft(v2,l2,1);

    for(int i=0;i<sz;i++){
        res[i] = (v1[i]*v2[i]);
    }

    fft(res,sz,-1);

    return sz;
}

Complex v[1024*1024];
Complex v2[1024*1024];
Complex ans[1024*1024];

int main(){ 
    int n,k;
    cin>>n>>k;

    for(int i=1;i<=n;i++){
        int num;
        cin>>num;
        v[num].r = 1.0;
    }

    int sz = 1024;
    ans[0].r = 1;
    while(k){
        if(k&1){
            for(int i=0;i<sz;i++){
                v2[i] = v[i];
            }
            mul(v2,sz,ans,sz,ans);      
        }

        for(int i=0;i<sz;i++){
            v2[i] = v[i];
        }
        sz = mul(v,sz,v2,sz,v);

        k>>=1;
    }
    sz = fix(ans,sz);

    for(int i=0;i<sz;i++){
        if(ans[i].r > eps){
            printf("%d ",i);
        }
    }
    return 0;
}
發佈了455 篇原創文章 · 獲贊 16 · 訪問量 22萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章