多項式乘法優化 學習筆記

今早重新看了myy的論文,又掌握了一些多項式乘法的新姿勢,於是寫一篇blog鞏固一下QAQ。


①如何用一次DFT加一次IDFT求出兩個實序列A和B的卷積?

這裏我們只要求卷積後的結果,不需要求DFT的值,所以有一種很簡便的方法:令複數序列C的實部爲A,虛部爲B。將其自卷,所得結果虛部的值除以2就是要求的多項式。

這個十分容易證明:

C2[k]=j=0kC[j]C[kj]

=j=0k(A[j]+iB[j])(A[kj]+iB[kj])

=j=0k(A[j]A[kj]B[j]B[kj])+ij=0k(A[j]B[kj]+B[j]A[kj])

②如何用一次DFT同時求出兩個實序列在單位複數根處的點值?

這個推導就很複雜了,大概就是一堆三角函數和i 換來換去,具體要看myy的論文,我也不再贅述。最終概括一下做法,就是設:

P(x)=A(x)+iB(x)

Q(x)=A(x)iB(x)

P(x),Q(x) 在單位複數根處的點值表達分別爲FP,FQ ,則可以證明FP(k)FQ(Nk) 互爲共軛複數。因此只需要對P(x) 進行DFT即可。然後會有:

DFTA=FP+FQ2

DFTB=FPFQ2i

(myy論文裏第二條式子寫的是乘以i ,不過我覺得是除以i 纔對)


③關於單位複數根ωNk

這個有三種算法。一種是累乘,時間快但精度不高。還有一種是直接每次用2πkN 的三角函數去算,這樣比較慢,但精度高。最後一種是預處理,然後用vector之類的存下來。第三種方法比較折中,在拆係數+FFT處理任意模數多項式卷積的時候經常用。


貼一份模板題(洛谷P3803)的CODE:
因爲只用了兩次DFT,所以完全不怕卡常

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<vector>
using namespace std;

const int maxn=4000000;
const double pi=acos(-1.0);

struct Complex
{
    double X,Y;
    Complex (double a=0.0,double b=0.0) : X(a),Y(b) {}
} ;

Complex operator+(Complex a,Complex b){return Complex(a.X+b.X,a.Y+b.Y);}
Complex operator-(Complex a,Complex b){return Complex(a.X-b.X,a.Y-b.Y);}
Complex operator*(Complex a,Complex b){return Complex(a.X*b.X-a.Y*b.Y,a.X*b.Y+a.Y*b.X);}

Complex A[maxn];
Complex B[maxn];

vector <Complex> w[maxn];
int Rev[maxn];
int N,Lg;

int F[maxn];
int G[maxn];
int n,m;

void DFT(Complex *a,double f)
{
    for (int i=0; i<N; i++)
        if (i<Rev[i]) swap(a[i],a[ Rev[i] ]);

    for (int len=2; len<=N; len<<=1)
    {
        int mid=(len>>1);
        for (Complex *p=a; p!=a+N; p+=len)
            for (int i=0; i<mid; i++)
            {
                Complex temp=w[mid][i];
                if (f==-1.0) temp.Y=-temp.Y;
                temp=temp*p[mid+i];
                p[mid+i]=p[i]-temp;
                p[i]=p[i]+temp;
            }
    }
}

void FFT()
{
    N=1,Lg=0;
    while (N<n+m+4) N<<=1,Lg++;
    for (int i=0; i<N; i++)
        for (int j=0; j<Lg; j++)
            if (i&(1<<j)) Rev[i]|=(1<<(Lg-j-1));

    int len=1;
    while ((len<<1)<=N)
    {
        double ang=pi/len;
        for (int i=0; i<len; i++)
            w[len].push_back( Complex( cos(ang*(double)i) , sin(ang*(double)i) ) );
        len<<=1;
    }

    for (int i=0; i<N; i++) A[i]=Complex((double)F[i],(double)G[i]);
    DFT(A,1.0);
    for (int i=0; i<N; i++) B[i]=A[(N-i)%N],B[i].Y=-B[i].Y;
    for (int i=0; i<N; i++)
    {
        Complex a=A[i]+B[i];
        a.X/=2.0;
        a.Y/=2.0;
        B[i]=A[i]-B[i];
        B[i].X/=2.0;
        B[i].Y/=2.0;
        swap(B[i].X,B[i].Y);
        B[i].Y=-B[i].Y;
        A[i]=a;
    }
    for (int i=0; i<N; i++) A[i]=A[i]*B[i];
    DFT(A,-1.0);
    for (int i=0; i<N; i++) A[i].X/=((double)N);
    for (int i=0; i<N; i++) F[i]=(int)floor( A[i].X+0.5 );
}

int main()
{
    freopen("3803.in","r",stdin);
    freopen("3803.out","w",stdout);

    scanf("%d%d",&n,&m);
    for (int i=0; i<=n; i++) scanf("%d",&F[i]);
    for (int i=0; i<=m; i++) scanf("%d",&G[i]);

    FFT();
    for (int i=0; i<=n+m; i++) printf("%d ",F[i]);
    printf("\n");

    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章