HDOJ 1402. A * B Problem Plus (FFT快速傅里葉變換)

Problem Description
Calculate A * B.

Input
Each line will contain two integers A and B. Process to end of file.
Note: the length of each integer will not exceed 50000.

Output
For each case, output A * B in one line.

這是一道套用FFT模板的題目,因爲剛學習了FFT算法知識,就拿來練手。

對於兩個多項式相乘問題,a0+a1x1+a2x2+...+an1xn1b0+b1x1+b2x2+...+bn1xn1 ,FFT可以通過求值和插值的方法,獲得O(nlgn) 的時間複雜度。

求值:對於函數A(x)=a0+a1x1+a2x2+...+an1xn1 ,求得n個不同的點值(xi,A(xi))

對函數公式進行變換

(1)A(x)=a0+a1x1+a2x2+...+an1xn1(2)=(a0+a2x2+a4x4+...+an2xn2)+(a1x1+a3x3+a5x5+...+an1xn1)(3)=(a0+a2x2+a4x4+...+an2xn2)+x(a1+a3x2+a5x4+...+an1xn2)(4)(assume that n is even)

試想一下,如果我們所要求的A(x1),A(x2)x12=x22 ,那麼對於上式中的兩個括號中的值,只需要計算一次即可,然後分別進行一次乘法和一次加法。對於求n個A(xi) 的值,如果我們選取的n的xi ,兩兩之間能夠滿足xi2=xj2 ,我們的計算量相當於是進行了折半。但還是不夠。

如果這個過程能夠遞歸地進行下去,也就是我們對a0+a2x2+a4x4+...+an2xn2a1+a3x2+a5x4+...+an1xn2 也能夠用同樣的方式求得n/2個點的值(這裏進行一個變量替換y=x2 ,便可以得到類似A(x) 的規模更小的多項式,所以我們能進行遞歸),那麼這個遞歸算法爲T(n)=2T(n/2)+O(n) ,根據Master Theory,可以知道算法的複雜度爲O(nlgn) 。要使這個算法能夠遞歸地進行下去,要達到的條件是在遞歸的每一層,我們總能夠找到兩兩成對的xi2=xj2

要達到這種要求,就需要在複數空間對1開n次方根得到n個值wn0,wn1,wn2,...wnn1 ,到下一層遞歸時又可以繼續下去。這裏數學的推導相對複雜,對於複數的知識以及具體遞歸的處理可以參考《算法導論》FFT章節。

對兩個多項式分別進行求值之後,進行點值乘法可以得到結果多項式上的n個點值(xi,C(xi))C(xi)=A(xi)B(xi)

插值:得到n個點值(xi,C(xi)) 後,求A(x)=c0+c1x1+c2x2+...+cn1xn1 中n個係數值c0,c1,c2...cn1

從數學上推導過來,在求值(xk=wnk,A(xk))

A(xk)=j=0n1ajwnkj(k=0,1,...,n1)

在插值時
ck=1nj=0n1C(xj)wnkj(j=0,1,...,n1)

(具體的數學推導請參考其他資料)這樣來看,插值和求值變爲了類似的過程。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <complex>
#include <algorithm>
using namespace std;

#define PI 3.14159265358979323846
#define MAX_N 1 << 17                 // error

char a[MAX_N], b[MAX_N];
complex<double> A[MAX_N], B[MAX_N], temp[MAX_N];
int res[MAX_N];

void reverse_copy(char* a, complex<double>* A, int n, int k) {
    // n >>= 1; if n == 1, then n = 0   // error
    for (int i = 0; i < n / 2; i++)
        swap(a[i], a[n - 1 - i]);
    // n <<= 1; if n == 0, then n = 0
    for (int i = 0; i < k; i++)
        A[i] = (i < n) ? complex<double>(a[i] - '0') : complex<double>(0);
}

int rev(int k, int lg_n) {
    int r = 0;
    for (int i = 0; i < lg_n; i++) {
        r <<= 1;
        r |= (k & 1);
        k >>= 1;
    }
    return r;
}
/*
void bit_reverse_copy(complex<double>* A, int k) {
    if (k == 1) { return; }
    int lg_k = 0;
    for (int i = 1; i < k; i <<= 1, lg_k++);
    for (int i = 0; i < k; i++) temp[i] = A[i];
    for (int i = 0; i < k; i++)
        A[rev(i, lg_k)] = temp[i];
}
*/

void bit_reverse_swap(complex<double>* A, int n) {
    int lg_n = 0;
    for (int i = 1; i < n; i <<= 1, lg_n++);
    for (int k = 0; k < n; k++)
        if (k < rev(k, lg_n))
            swap(A[k], A[rev(k, lg_n)]);
}

void FFT(complex<double>* A, int n, int flag) {
    // bit_reverse_copy(A, n);
    bit_reverse_swap(A, n);
    int s, j, k, t, st, lg_n = 1;
    complex<double> w, u, v, w_n;
    for (int i = 2; i < n; i <<= 1, lg_n++);
    for (s = 0; s < lg_n; s++) {
        int l = 1 << (s + 1);
        w_n = complex<double>(cos(flag * 2 * PI / l),
                              sin(flag * 2 * PI / l));
        for (t = 0; t < n / l; t++) {
            w = 1;
            st = t * l;
            for (j = 0; j < (1 << s) ; j++) {
                u = A[st + j];
                v = A[st + j + (1 << s)];
                v *= w;
                A[st + j] += v;
                A[st + j + (1 << s)] = u - v;
                w *= w_n;
            }
        }
    }
    // int s, j, k, st, lg_n = 1;
    // complex<double> w, u, v, w_n, t;
    // for (int i = 2; i < n; i <<= 1, lg_n++);
    // for (s = 1; s <= lg_n; s++) {
    //     int m = 1 << s;
    //     w_n = complex<double>(cos(flag * 2 * PI / m), sin(flag * 2 * PI / m));
    //     for (k = 0; k <= n - 1; k += m) {
    //         w = 1;
    //         for (j = 0; j <= m / 2 - 1; j++) {
    //             t = w * A[k + j + m / 2];
    //             u = A[k + j];
    //             A[k + j] = u + t;
    //             A[k + j + m / 2] = u - t;
    //             w *= w_n;
    //         }
    //     }
    // }
    if (flag == -1) {  
        for (int i = 0; i < n; i++) {
            A[i] /= complex<double>(n);
        }
    } 
}


int main() {

    while (~scanf("%s%s", a, b)) {        
        int n = strlen(a), m = strlen(b), t = n + m;
        int k = 1;
        for (; k < t; k <<= 1);

        reverse_copy(a, A, n, k);
        reverse_copy(b, B, m, k);

        FFT(A, k, 1);
        FFT(B, k, 1);
        for (int i = 0; i < k; i++) {
            A[i] *= B[i];
        }
        FFT(A, k, -1);
        for (int i = 0; i < k; i++) {
            res[i] = int(A[i].real() + 0.5);
        }
        for (int i = 0; i < k - 1; i++) {
            if (res[i] >= 10) {
                // res[i + 1] = res[i] / 10;    // error
                res[i + 1] += res[i] / 10;
                res[i] %= 10;
            }
        }

        int j;
        for (j = k - 1; j >= 0 && res[j] == 0; j--);
        if (j < 0) {
            printf("0");
        } else {
            for (; j >= 0; j--) putchar(res[j] + '0');
        }
        puts("");
    }
    return 0;
}

貼上代碼,以上代碼參考了FFT 模板,基本上是在用這裏的代碼來debug,然後逐漸替換成自己的代碼。

這道題我一開始認爲每個數最大的長度爲50000,兩數相乘最多是100000,所以任意取了MAX_N 120000。但是因爲我們總是要把兩數的位數和補成2的次方,所以在計算過程中用的n一定是大於100000,會達到1<<17,所以WA了好久。

在非遞歸的實現中,有一個優化的地方是本來是要A[rev(k)]=ak ,這樣的話需用用另一個數組來輔助存儲。但是因爲k與rev(k)互爲二進制位的位逆序,所以交換即可。

參考
1. http://www.cnblogs.com/Patt/p/5503322.html
2. 《算法導論》

發佈了36 篇原創文章 · 獲贊 3 · 訪問量 7119
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章