FFT & NTT學習心得

基礎知識淺談

FFT—快速傅里葉變換

基本功能

在O( (n+m)log(n+m) )的時間複雜度內計算:n次多項式乘m次多項式.

實現方式

欲求多項式A*多項式B:
1.對多項式A,B分別進行快速傅里葉變換,得到A1,B1;
2.將A1,B1的對應項相乘得到多項式C1.即:C1[i]=A1[i]*B1[i] 其中C1[i]表示C1的第i項係數.
3.將多項式C1進行逆變換得到多項式C.則多項式C即是多項式A與多項式B相乘的結果.
具體證明這裏不多說了.

NTT—快速數論變換

基本功能

與FFT類似,但沒有精度差.

實現方式

與FFT類似,只是用數論原根的冪代替FFT中的w[n].

原根的含義及求法

根據歐拉定理,若a,p互質則有:

aϕ(p)1  (mod  p)

如果當且僅當x=ϕ(p)
ax1  (mod  p)

則a是p的一個原根.
用更專業的語言來描述即:a模p的階等於ϕ(p) 則a是p的原根.
原根的求法:沒有什麼特別的方法,只能暴力枚舉,驗證時略有技巧,這裏就不細說了.

關於NTT算法的一些說明

優點

相比起FFT來說,NTT最顯著的優勢在於沒有精度差.由於FFT會用到complex double ,在大數據下不排除出現精度差的可能.在某些評測機上,效率可能不如NTT.(NTT雖然不用複數運算,但是取模很多).

限制

相比起FFT來說,NTT的限制很多.

  • 所求的多項式要求是整係數
  • 如果題目要求結果對質數p取模,這個質數往往只能是998244353,否則會有很多麻煩,這個會在後面談到.
  • 所求多項式的項數應在223 之內,因爲998244353=717223+1
  • 結果的係數不應超過質數P.(P是自己選擇的質數,一般定爲P=998244353)

具體代碼實現

FFT模板

const double pi=3.1415926535897932384626433832795;
complex<double>Wi[MAXN];
char s[MAXN],t[MAXN];
void FFT(complex<double> A[],int nn,int ty)
{
    int i,j,k,m;
    complex<double> t0,t1;
    for(i=0;i<nn;i++)
    {
        for(j=0,k=i,m=1;m<nn;m<<=1,j=(j<<1)|(k&1),k>>=1);
        if(i<j)t0=A[i],A[i]=A[j],A[j]=t0;
    }//這段for循環不建議擅自改動,極易出錯.
    Wi[0]=1;
    for(m=1;m<nn;m<<=1)
    {
        t0=exp(complex<double>(0,ty*pi/m));
        for(i=1;i<m;i++)
            Wi[i]=Wi[i-1]*t0;
        for(k=0;k<nn;k+=m<<1)
            for(i=k;i<k+m;i++)
            {
                t0=A[i];
                t1=A[i+m]*Wi[i-k];
                A[i]=t0+t1;
                A[i+m]=t0-t1;
            }
    }
    if(ty==1)
        return;//ty==-1時爲逆變換.
    t0=1.0/nn;
    for(i=0;i<nn;i++)
        A[i]*=t0;
}

NTT模板

const int P=998244353;
const int g=3;//P的原根
int W[MAXN];
int exp(int a,int k)
{
    ll A=1LL*a,ANS=1LL;
    for(;k;k>>=1,A=A*A%P)
    {
        if(k&1)
        {
            ANS=ANS*A%P;
        }
    }
    return (int)ANS%P;
}//其實這裏的函數名取爲exp是不恰當的,不過無傷大雅.
void NTT(int A[],int nn,int ty)
{
    int t1,t2,i,j,k,m;
    for(i=0;i<nn;i++)
    {
        for(j=0,k=i,m=1;m<nn;m<<=1,j=(j<<1)|(k&1),k>>=1);
        if(i<j)
        {
            t1=A[i];
            A[i]=A[j];
            A[j]=t1;
        }
    }
    W[0]=1;
    for(m=1;m<nn;m<<=1)
    {
        t1=exp(g,P-1+ty*(P-1)/(m<<1));
        for(i=1;i<m;i++)
        {
            W[i]=1LL*W[i-1]*t1%P;
        }
        for(k=0;k<nn;k+=m<<1)
        {
            for(i=k;i<k+m;i++)
            {
                t1=A[i];
                t2=1LL*A[i+m]*W[i-k]%P;
                A[i]=t1+t2;
                A[i]-=A[i]>P?P:0;
                A[i+m]=t1-t2;
                A[i+m]+=A[i+m]<0?P:0;
            }
        }
    }
    if(ty==1)
    {
        return ;
    }
    t1=exp(nn,P-2);
    for(i=0;i<nn;i++)
    {
        A[i]=1LL*A[i]*t1%P;
    }
    return ;
}

FFT及NTT的簡單應用

最基本應用–高精度乘法

將十進制數看成一個多項式,然後利用FFT或NTT求解.(注意進位).
比如:
123*456=56088可以看成:

123=1102+2101+3100

456=4102+5101+6100

123456=4104+13103+28102+27101+18100

考慮進位則有:
123456=5104+6103+0102+8101+8100=56088

稍微高端一點的應用

特殊的計數問題

題目大意

給出兩個長度分別爲n,m的01序列A、B,有Q次詢問,每次詢問要求回答:當把A的第i位和B的第j位對齊時,A,B公共部分有多少對對齊且相同的數.
如:
輸入n,m,A串,B串,Q,每次詢問(i,j)
6 6
000111
111100
3
1 1
1 2
4 2
輸出每次詢問的答案:
1
0
3

數據範圍:

n<=100000,Q<=1000000

解題方法

注意到所有詢問總共只有2n-1種本質不同的詢問,我們希望預處理出所有這些本質不同的詢問的答案.
比如詢問爲(i,j)我們可以等價轉化爲(1,j-i),因此我們不妨將詢問設爲(1,k)
我們先考慮有多少位上的1是對齊的:
其實兩個位置上的1對齊,當且僅當兩個位置都是1,且兩個位置相差k-1.
如果將其中一個序列倒序排列的話,那麼兩個之前對齊的位置現在的位置標號之和就是定值了.於是就可以構造多項式進行計數.
這之後,我們將所有0,1取反,就可以同樣計算出有多少位0是對齊的.

參考代碼

#include<cstdio>
#include<cstring> 
#define MAXN 425000
const int P=998244353;
const int G=3;
char sa[MAXN],sb[MAXN];
int N,W[MAXN*2];
void _r(int& x)
{
    char c=getchar();
    while(c<'0'||c>'9')
    {
        c=getchar();
    }
    for(x=0;c>='0'&&c<='9';c=getchar())
    {
        x=(x<<1)+(x<<3)+c-'0';
    }
    return ;
}
bool check(int x[],int nn)
{
    for(int i=0;i<nn;i++)
    {
        if(x[i])
        {
            return false;
        }
    }
    return true;
}
int CNT[MAXN*2];
int exp(int a,int k)
{
    int ans=1;
    for(;k;k>>=1)
    {
        if(k&1)
        {
            ans=1ll*ans*a%P;
        }
        a=1ll*a*a%P;
    }
    return ans;
}
void fft(int A[],int nn,int ty)
{
    int t1,t2,i,j,k,m;
    for(i=0;i<nn;i++)
    {
        for(j=0,k=i,m=1;m<nn;m<<=1,j=(j<<1)|(k&1),k>>=1);
        if(i<j)
        {
            t1=A[i];
            A[i]=A[j];
            A[j]=t1;
        }
    }
    W[0]=1;
    for(m=1;m<nn;m<<=1)
    {
        t1=exp(G,P-1+ty*(P-1)/(m<<1));
        for(i=1;i<m;i++)
        {
            W[i]=1ll*W[i-1]*t1%P;
        }
        for(k=0;k<nn;k+=m<<1)
        {
            for(i=k;i<k+m;i++)
            {
                t1=A[i];
                t2=1ll*A[i+m]*W[i-k]%P;
                A[i]=t1+t2;
                A[i]-=A[i]>P?P:0;
                A[i+m]=t1-t2;
                A[i+m]+=A[i+m]<0?P:0;
            }
        }
    }
    if(ty==1)
    {
        return ;
    }
    t1=exp(nn,P-2);
    for(i=0;i<N;i++)
    {
        A[i]=1ll*A[i]*t1%P;
    }
    return ;
}//其實是NTT只是函數名取成了fft
int A[MAXN*2],B[MAXN*2],n,m,q;
int o[12],tot=0;
void PUT(int x)
{
    if(x==0)
    {
        putchar('0');
        return ;
    }
    for(tot=0;x;x/=10)
    {
        o[++tot]=x%10;
    }
    for(;tot;tot--)
    {
        putchar(o[tot]+'0');
    }
    return ;
}
int main()
{
    _r(n);
    _r(m);
    scanf("%s%s",sa,sb);
    for(int i=0;i<n;i++)
    {
        A[i]=sa[i]-'0';
    }
    for(int i=0;i<m;i++)
    {
        B[i]=sb[m-i-1]-'0';
    }
    for(N=1;N<n+m+1;N<<=1);
    if(check(A,N)||check(B,N));//特判兩個多項式是否存在爲0的,否則用NTT會有詐.
    else
    {
        fft(A,N,1);
        fft(B,N,1);
        for(int i=0;i<N;i++)
        {
            A[i]=1ll*A[i]*B[i]%P;
        }
        fft(A,N,-1);
        for(int i=0;i<N;i++)
        {
            CNT[i]+=A[i];
        }
    }
    memset(A,0,sizeof(A));
    memset(B,0,sizeof(B));
    for(int i=0;i<n;i++)
    {
        A[i]=sa[i]-'0';
        A[i]^=1;
    }
    for(int i=0;i<m;i++)
    {
        B[i]=sb[m-i-1]-'0';
        B[i]^=1;
    }
    if(check(A,N)||check(B,N));
    else
    { 
        fft(A,N,1);
        fft(B,N,1);
        for(int i=0;i<N;i++)
        {
            A[i]=1ll*A[i]*B[i]%P;
        }
        fft(A,N,-1);
        for(int i=0;i<N;i++)
        {
            CNT[i]+=A[i];
        }
    } 
    _r(q);
    for(int i=1,u,v;i<=q;i++)
    {
        _r(u);
        _r(v);
        int p=u-1+m-v;
        PUT(CNT[p]);
        putchar('\n');
    }
    return 0;
} 

還有一道很牛逼的題目—萬徑人蹤滅,也用到了這個方法.

NTT&FFT與多項式問題

相關知識

  • 多項式求逆
  • 多項式除法
  • 多項式取模
  • 多項式開方
  • 多項式快速冪
  • 多項式取對數
  • ………………

這些東西本人也是一知半解,具體算法的證明過程,代碼實現可以參考網上的其他博客。這裏給一個參考的鏈接(加載非常慢,但內容非常好,如有需要請耐心等待):http://picks.logdown.com/posts/197262-polynomial-division

簡單地談一談

多項式求逆

已知多項式A,欲求多項式B,使AB1 (mod xn) .
首先,常數項取B[0]=A[0]1
假如我們已經得到B’,使得AB1 (mod xn) ,欲求B使得AB1 (mod x2n)
那麼有

A(BB)0 (mod xn)

於是
BB0 (mod xn)

平方
B2+B22BB0 (mod x2n)

同時乘A得
AB2+AB22ABB0 (mod x2n)

B+AB22B0 (mod x2n)

於是得到
B2BAB2 (mod x2n)

這裏用到了倍增的思想.
多項式開方與之類似.
時間複雜度
T(n)=T(n/2)+O(nlog n),O(nlog n)

一些典型題目

貝殼串

【問題描述】
海邊市場有長度分別爲1到n的貝殼串出售,其中長度爲i的貝殼串有a[i]種,每種貝殼串有無限個,問用這些貝殼串鏈接成長度爲n的串有多少種方案?

【輸入格式】
第一行,一整數n,
第二行,n個整數ai表示長度爲i的貝殼串的種類數

【輸出格式】
輸出方案數,結果模313

【輸入輸出樣例】
in
3
1 3 7
out
14

in
4
2 2 2 2
out
54

int
7
86 58 87 145 510 32 263
out
152
【數據範圍】
對於50%的數據n<=1000
對於100%的數據n<=100000,0<=ai<=10000000

官方題解

FFT+分治

另解

FFT+多項式求逆

然而我是一個異端

題目要求對313取模,本是斷絕了NTT的路,但我硬是強用NTT把這個題過了!
做法非常的“淫蕩”.爲此我專門寫了一篇題解:

另類題解

【貝殼串shell】的另類解法—-強行NTT.

1.DP方程
設f[i]爲配成長度爲i的項鍊的方案數,a[i]表示長度爲i的貝殼的種類數.
枚舉最後一串貝殼的長度,可得轉移方程:

f[i]=i=1nf[ni]a[i]

並規定f[0]=1 .
直接暴力DP可以拿到50分.

2.構造母函數

F(x)=i=0+f[i]xi        A(x)=i=0+a[i]xi

於是根據dp方程有:
F(x)=F(x)A(x)+1

等價於:
F(x)=11A(x)

到此爲止用FFT就足以解決問題了.

3.強用NTT
我們知道,多項式求逆需要多次使用NTT,這樣不可避免地會使係數對質數P(P=998244353)取模.但是題目要求對313取模,這樣得到的答案就不對了.

舉個例子吧:
用NTT計算(1+x)^n的一次項係數,答案對9973取模.如果我們直接把多項式進行(n-1)次乘法,那麼得到的結果將是n%998244353%9973,這與n%9973是不等價的.
但是如果在每次乘以(x+1)後,將結果的各項係數都對9973取模,得到的將是正確的結果.因爲NTT實際上是在大質數的剩餘系下進行乘法運算,如果單次NTT變換的結果不會超出大質數,所得結果將與正常的乘法一樣.

所以,貝殼串這道題也可以使用這樣的辦法.

很遺憾的是,就貝殼串一題而言,單次運算的上界=313*313*100000,大約100億左右,顯然超出了998244353之類的數的範圍,怎麼辦?

我們可以在P=998244353*1004535809的剩餘系下完成乘法。(998244353,1004535809都是常見的用於NTT的大質數,本人更喜歡用前者.)我們可以在每次NTT之後都把係數對313取模,這樣中間過程一直都不會超過P.

但是P=998244353*1004535809不能直接用於NTT,於是我們可以把多項式分別在998244353,1004535809的剩餘系下進行NTT變換,之後用中國剩餘定理求出係數模P的結果.

不過還有一個坑點,這樣得到的結果不一定是真實的結果,不能簡單地對313取模.按照以上做法,得到的結果除了可能是真實結果之外,還有可能是(真實結果+P),這時的真實結果是負數.因爲在做NTT變換時,負數一般是+P轉爲正數的.好在負數結果+P的值要遠大於正數結果,我們可以很簡單區分開二者,然後加以特判即可.其實這個問題非常容易被忽視!考場上我起初是沒有發現這個問題的,調了將近一個小時才發現這個問題。

至此這個問題終於講清楚了,貼上我那醜陋的代碼!

//1004535809 998244353
#include<cstdio>
#define MAXN 450000
#define LL long long
#define mod 313
const int P=998244353,Q=1004535809,G=3;
const LL PQ=1ll*998244353*1004535809,lim=1ll*mod*mod*400000;
int exp(int a,int k,int MOD)
{
    int an=1;
    for(;k;k>>=1)
    {
        if(k&1)
        {
            an=1ll*an*a%MOD;
        }
        a=1ll*a*a%MOD;
    }
    return an;
}
const int ip=exp(P,Q-2,Q),iq=exp(Q,P-2,P);
int A[MAXN],F[MAXN];
int n;
void _r(int& x)
{
    char c=getchar();
    while(c<'0'||c>'9')
    {
        c=getchar();
    }
    for(x=0;c>='0'&&c<='9';c=getchar())
    {
        x=(x<<1)+(x<<3)+c-'0';
    }
    return ;
}
void work1()
{
    F[0]=1;
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=i;j++)
        {
            F[i]=(F[i]+F[i-j]*A[j]%mod)%mod;
        }
    }
    printf("%d\n",F[n]);
    return ;
}
int W[MAXN];
void NTT(int A[],int nn,int ty,int M)
{
    int i,j,t1,t2,k,m;
    for(i=0;i<nn;i++)
    {
        for(j=0,k=i,m=1;m<nn;m<<=1,j=(j<<1)|(k&1),k>>=1);
        if(i<j)
        {
            t1=A[i];
            A[i]=A[j];
            A[j]=t1;
        }
    }
    W[0]=1;
    for(m=1;m<nn;m<<=1)
    {
        t1=exp(G,M-1+ty*(M-1)/(m<<1),M);
        for(i=1;i<m;i++)
        {
            W[i]=1ll*W[i-1]*t1%M;
        }
        for(k=0;k<nn;k+=m<<1)
        {
            for(i=k;i<k+m;i++)
            {
                t1=A[i];
                t2=1ll*A[i+m]*W[i-k]%M;
                A[i]=t1+t2;
                A[i]-=A[i]>=M?M:0;
                A[i+m]=t1-t2;
                A[i+m]+=A[i+m]<0?M:0;
            }
        }
    }
    if(ty==1)
    {
        return ;
    }
    t1=exp(nn,M-2,M);
    for(i=0;i<nn;i++)
    {
        A[i]=1ll*A[i]*t1%M;
    }
    return ;
}
LL mul(LL a,LL b)
{
    LL an=0;
    for(;b;b>>=1)
    {
        if(b&1)
        {
            an=an+a;
            an-=an>=PQ?PQ:0;
        }
        a=a+a;
        a-=a>=PQ?PQ:0;
    }
    return an;
}
int T1[MAXN],T2[MAXN],tmp[MAXN];
void inv(int A[],int X[],int nn)
{
    int i,j,k,t,m,n;
    LL t5;
    X[0]=exp(A[0],mod-2,mod);
    for(m=1;m<nn;m<<=1)
    {
        n=m<<1;
        for(i=0;i<n;i++)
        {
            T1[i]=A[i];
            T2[i]=A[i];
        }
        for(n<<=1;i<n;i++)
        {
            T1[i]=0;
            T2[i]=0;
        }
        for(i=0;i<n;i++)
        {
            tmp[i]=X[i];
        }
        NTT(T1,n,1,P);
        NTT(tmp,n,1,P);
        for(i=0;i<n;i++)
        {
            T1[i]=2+P-1ll*T1[i]*tmp[i]%P;
            T1[i]-=T1[i]>=P?P:0;
            T1[i]+=T1[i]<0?P:0;
            tmp[i]=1ll*tmp[i]*T1[i]%P;
        }
        NTT(tmp,n,-1,P);
        NTT(T2,n,1,Q);
        NTT(X,n,1,Q);
        for(i=0;i<n;i++)
        {
            T2[i]=2+Q-1ll*T2[i]*X[i]%Q;
            T2[i]-=T2[i]>=Q?Q:0;
            T2[i]+=T2[i]<0?Q:0;
            X[i]=1ll*X[i]*T2[i]%Q;
        }
        NTT(X,n,-1,Q);
        for(i=0;i<n;i++)
        {
            t5=(mul(mul(1ll*X[i],1ll*P),1ll*ip)+mul(mul(1ll*tmp[i],1ll*Q),1ll*iq))%PQ;
            if(t5>lim)
            {
                X[i]=t5%mod-PQ%mod;
                X[i]+=X[i]<0?mod:0;
            }
            else
            {
                X[i]=t5%mod;
            }
        }
        for(i=n>>1;i<n;i++)
        {
            X[i]=0;
        }
    }
    for(i=nn;i<n;i++)
    {
        X[i]=0;
    }
    for(i=0;i<n;i++)
    {
        tmp[i]=T1[i]=T2[i]=0;
    }
    return ;
}       
void work2()
{
    A[0]=1;
    for(int i=1;i<=n;i++)
    {
        A[i]%=mod;
        A[i]=mod-A[i];
    }
    inv(A,F,n+1);
    printf("%d\n",(F[n]%mod+mod)%mod);
    return ;
}
int main()
{
    //freopen("shell.in","r",stdin);
    //freopen("shell.out","w",stdout);
    _r(n);
    for(int i=1;i<=n;i++)
    {
        _r(A[i]);
    }
    if(n<=1100)
    {
        work1();
    }
    else
    {
        work2();
    }
    return 0;
}

上面的代碼很亂,是考場上寫的,有很多冗餘的變量,將就了看吧.
對了,忘記說了,P取得相當大,用中國剩餘定理的時候,乘法要用快速乘,不然會超long long.


小朋友與二叉樹

題目描述參見bzoj 3625 / Codeforces Round #250

題解

這道題是一個計數問題,容易想到dp的做法.
設f[n]表示點權和爲n的二叉樹有多少種.考慮n的組成:
左子樹點權和+右子樹點權和+根節點權值=n.
設權值爲k的數給定序列中出現次數爲c[k],則根據分步計數原理:

f[n]=i+j+k=nf[i]f[j]c[k]

構造多項式:
F(x)=i=0+f[i]xi      C(x)=i=0+c[i]xi

則有:
F(x)=F(x)F(x)C(x)

???有詐!!!
我們在遞推式中默認F(0)=f[0]=1.這樣纔可以包含子樹爲空的情形.但是C(0)=c[0]=0.
於是F(0)=F(0)*F(0)*C(0)=C(0)=0,推出矛盾!
所以正確的式子:
F(x)=F(x)F(x)C(x)+1

解得:
F(x)=1±14C(x)2C(x)

這種方程真的能這麼解嗎?其實只需要知道,把這個解代入原方程,方程肯定成立.我能解釋的也只有這麼多了.
注意到,分母沒有常數項,分子要整除分母,則分子也不能有常數項,即±應取負號.
爲了簡單,再做一下變形:
F(x)=1±14C(x)2C(x)=21+14C(x)

用多項式開方即可.

參考代碼

算了,這題我就不給代碼了!本人2017年3月1日在bzoj提交了這道題,並在當時莫名其妙地超過了各位大牛,拿到了rank 1.如果放出代碼,rank 1豈不是保不住了!!!


bzoj 4555

第二類斯特林數

這個題目做法很多,NTT+分治或者多項式求逆都可以做.
只要式子化得好,直接上裸的NTT也是可以做的.
網上題解很多,這裏就不多說了.

參考代碼(不用求逆或分治的常數最小的版本)

這個做法的數學推導比較妙,建議先學習一下第二類斯特林數的展開式.

#include<cstdio>
#define MAXN 410000
#define P 998244353
#define g 3
int n,A[MAXN],B[MAXN],N;
int exp(int a,int k,int mod=P)
{
    int an=1;
    for(;k;k>>=1)
    {
        if(k&1)
        {
            an=1ll*an*a%mod;
        }
        a=1ll*a*a%mod;
    }
    return an;
}
int fac[MAXN],inv[MAXN],W[MAXN];
void NTT(int A[],int nn,int ty)
{
    int i,j,k,m,t1,t2;
    for(i=0;i<nn;i++)
    {
        for(j=0,k=i,m=1;m<nn;m<<=1,j=(j<<1)|(k&1),k>>=1);
        if(i<j)
        {
            t1=A[i];
            A[i]=A[j];
            A[j]=t1;
        }
    }
    W[0]=1;
    for(m=1;m<nn;m<<=1)
    {
        t1=exp(g,P-1+ty*(P-1)/(m<<1),P);
        for(i=1;i<m;i++)
        {
            W[i]=1ll*W[i-1]*t1%P;
        }
        for(k=0;k<nn;k+=m<<1)
        {
            for(i=k;i<k+m;i++)
            {
                t1=A[i];
                t2=1ll*A[i+m]*W[i-k]%P;
                A[i]=t1+t2;
                A[i]-=A[i]>=P?P:0;
                A[i+m]=t1-t2;
                A[i+m]+=A[i+m]<0?P:0;
            }
        }
    }
    if(ty==1)
    {
        return ;
    }
    t1=exp(nn,P-2,P);
    for(i=0;i<nn;i++)
    {
        A[i]=1ll*A[i]*t1%P;
    }
    return ;
}
int main()
{
    scanf("%d",&n);
    fac[0]=inv[0]=1;
    for(int i=1;i<=n;i++)
    {
        inv[i]=1ll*inv[i-1]*i%P;
        fac[i]=inv[i];
    }
    inv[n]=exp(inv[n],P-2,P);
    for(int i=n;i;i--)
    {
        inv[i-1]=1ll*inv[i]*i%P;
    }
    for(int i=0,p;i<=n;i++)
    {
        if(i&1)
        {
            A[i]=P-inv[i];
        }
        else
        {
            A[i]=inv[i];
        }
        if(i==0)
        {
            B[i]=1;
        }
        else if(i==1)
        {
            B[i]=1ll*(n+1)*inv[i]%P;
        }
        else
        {
            B[i]=1ll*(exp(i,n+1)-1+P)%P*exp(i-1,P-2,P)%P*inv[i]%P;
        }
    }
    for(N=1;N<=n*2;N<<=1);
    NTT(A,N,1);
    NTT(B,N,1);
    for(int i=0;i<N;i++)
    {
        A[i]=1ll*A[i]*B[i]%P;
    }
    NTT(A,N,-1);
    int ans=0,tmp;
    for(int i=0;i<=n;i++)
    {
        tmp=1ll*exp(2,i,P)*fac[i]%P;
        tmp=1ll*tmp*A[i]%P;
        ans=ans+tmp;
        ans-=ans>=P?P:0;
    }
    printf("%d\n",ans);
    return 0;
}

【SDOI2013 R1 Day2】淘金

題目來源:sdoi 2013

題號:bzoj 3131

題目描述

小Z在玩一個叫做《淘金者》的遊戲。遊戲的世界是一個二維座標。X軸、Y軸座標範圍均爲1..N。初始的時候,所有的整數座標點上均有一塊金子,共N*N塊。
一陣風吹過,金子的位置發生了一些變化。細心的小Z發現,初始在(i,j)座標處的金子會變到(f(i),f(j))座標處。其中f(x)表示x各位數字的乘積,例如f(99)=81,f(12)=2,f(10)=0。如果金子變化後的座標不在1..N的範圍內,我們認爲這塊金子已經被移出遊戲。同時可以發現,對於變化之後的遊戲局面,某些座標上的金子數量可能不止一塊,而另外一些座標上可能已經沒有金子。這次變化之後,遊戲將不會再對金子的位置和數量進行改變,玩家可以開始進行採集工作。
小Z很懶,打算只進行K次採集。每次採集可以得到某一個座標上的所有金子,採集之後,該座標上的金子數變爲0。
現在小Z希望知道,對於變化之後的遊戲局面,在採集次數爲K的前提下,最多可以採集到多少塊金子?
答案可能很大,小Z希望得到對1000000007(10^9+7)取模之後的答案。

數據規模

N1012,k105.

題解

一些簡單的想法:
  • 處理出每個位置的金子數
  • 如果設g(x)表示所有滿足f(i)=x的i的數目,那麼(x,y)的金子數爲g(x)*g(y)
  • 如果我們處理出了g(x)的值,那麼這就是一個經典的用堆處理的問題了.
  • 處理g(x)應該可以用數位dp.
官方題解:數位dp+堆.
另一些想法:
  • 是不是可以暴力求解g(x)?
  • g(x)!=0,意味着x=2a3b5c7d ,這樣的x並不多!
  • 如果N106 ,我們暴力求解即可
  • N1012,N106106 我們暴力處理出106 以內的g(x),然後用NTT做一次乘法就可以得出1012 以內的g(x)了!
  • 顯然x可以大到912 ,不過我們可以分解質因數用a,b,c,d存儲.爲了方便起見,估計a,b,c,d的最大值,然後採取合適的進制就可以了.
  • 於是這道題居然就可以用NTT+暴力就水過去了!
  • 當然這個乘法並不簡單,考慮到數位dp的特點,我們在做乘法時要分有限制和無限制兩部分,具體實現看代碼吧

參考代碼

醜的要命的代碼:

#include<cstdio>
#include<queue>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 1080000
#define mod 1000000007
#define LL long long
#define G 998244353
#define K 1004535809
#define gg 3
int exp(int a,int k)
{
    int an=1;
    for(;k;k>>=1)
    {
        if(k&1)
        {
            an=1ll*an*a%G;
        }
        a=1ll*a*a%G;
    }
    return an;
}
int W[MAXN];
void NTT(int A[],int nn,int ty)
{
    int i,j,t1,t2,k,m;
    for(i=0;i<nn;i++)
    {
        for(j=0,k=i,m=1;m<nn;m<<=1,j=(j<<1)|(k&1),k>>=1);
        if(i<j)
        {
            t1=A[i];
            A[i]=A[j];
            A[j]=t1;
        }
    }
    W[0]=1;
    for(m=1;m<nn;m<<=1)
    {
        t1=exp(gg,G-1+ty*(G-1)/(m<<1));
        for(i=1;i<m;i++)
        {
            W[i]=1ll*W[i-1]*t1%G;
        }
        for(k=0;k<nn;k+=m<<1)
        {
            for(i=k;i<k+m;i++)
            {
                t1=A[i];
                t2=1ll*A[i+m]*W[i-k]%G;
                A[i]=t1+t2;
                A[i]-=A[i]>=G?G:0;
                A[i+m]=t1-t2;
                A[i+m]+=A[i+m]<0?G:0;
            }
        }
    }
    if(ty==1)
    {
        return ;
    }
    t1=exp(nn,G-2);
    for(i=0;i<nn;i++)
    {
        A[i]=1ll*A[i]*t1%G;
    }
    return ;
}
LL n,k;
int f[MAXN],g[MAXN],tot=0,h[MAXN],oto=0;
int ans=0;
void work1()
{
    for(int i=1,tt;i<=n;i++)
    {
        tt=1;
        for(int x=i;x;x/=10)
        {
            tt*=x%10;
        }
        f[tt]++;
    }
    for(int i=1;i<=n;i++)
    {
        if(f[i])
        {
            g[++tot]=f[i];
        }
    }
    sort(g+1,g+1+tot);
    for(int i=tot;i>=1&&i>=tot-1000;i--)
    {
        for(int j=tot;j>=1&&j>=tot-1000;j--)
        {
            h[++oto]=1ll*g[i]*g[j]%mod;
        }
    }
    sort(h+1,h+1+oto);
    for(int i=oto;i>=1&&i>=oto-k+1;i--)
    {
        ans+=h[i];
        ans-=ans>=mod?mod:0;
    }
    printf("%d\n",ans);
    return ;
}//分段,前面小數據直接暴力!
const int QQ=1000000;
LL n1,n2,m1,m2,AA[MAXN*3];
int pp[5]={0,2,3,5,7},qq[5]={0,40,26,18,15};
int A[MAXN],B[MAXN],C[MAXN],D[MAXN],NN=1048576;
struct node
{
    int x,y;
    LL val;
    node(int xx=0,int yy=0)
    {
        x=xx;
        y=yy;
        val=1ll*D[x]*D[y];
    }
};
bool cmp(LL a,LL b)
{
    return a>b;
}
bool operator < (node x,node y)
{
    return x.val<y.val;
}
priority_queue<node>Q;
void work2()
{
    n1=n/QQ;
    n2=n1-1;
    m1=n%QQ;
    m2=QQ-1;
    memset(f,0,sizeof(f)); 
    for(int i=QQ/10,tt;i<=m2;i++)
    {
        tt=1;
        for(int x=i;x;x/=10)
        {
            tt*=x%10;
        }
        f[tt]++;
    }
    tot=0;
    for(int i=1;i<=m2;i++)
    {
        if(f[i])
        {
            int tt=0;
            for(int j=4,q=i;j>=1;j--)
            {
                tt*=qq[j];
                while(q%pp[j]==0)
                {

                    q/=pp[j];
                    ++tt;
                }
            }
            D[tt]=f[i];
        }
    }
    memset(f,0,sizeof(f)); 
    for(int i=QQ/10,tt;i<=m1;i++)
    {
        tt=1;
        for(int x=i;x;x/=10)
        {
            tt*=x%10;
        }
        f[tt]++;
    }
    tot=0;
    for(int i=1;i<=m1;i++)
    {
        if(f[i])
        {
            int tt=0;
            for(int j=4,q=i;j>=1;j--)
            {
                tt*=qq[j];
                while(q%pp[j]==0)
                {

                    q/=pp[j];
                    ++tt;
                }
            }
            C[tt]=f[i];
        }
    }
    memset(f,0,sizeof(f)); 
    for(int i=1,tt;i<=n2;i++)
    {
        tt=1;
        for(int x=i;x;x/=10)
        {
            tt*=x%10;
        }
        f[tt]++;
    }
    tot=0;
    for(int i=1;i<=n2;i++)
    {
        if(f[i])
        {
            int tt=0;
            for(int j=4,q=i;j>=1;j--)
            {
                tt*=qq[j];
                while(q%pp[j]==0)
                {

                    q/=pp[j];
                    ++tt;
                }
            }
            B[tt]=f[i];
        }
    }
    int ttt=1;
    for(int x=n1;x;x/=10)
    {
        ttt*=x%10;
    }
    int tt=0;
    if(ttt)
    {
        for(int j=4,q=ttt;j>=1;j--)
        {
            tt*=qq[j];
            while(q%pp[j]==0)
            {
                q/=pp[j];
                ++tt;
            }
        }
    }
    A[tt]=1;
    NTT(A,NN,1);
    NTT(C,NN,1);
    for(int i=0;i<NN;i++)
    {
        A[i]=1ll*A[i]*C[i]%G;
    }
    NTT(A,NN,-1);
    NTT(B,NN,1);
    NTT(D,NN,1);
    for(int i=0;i<NN;i++)
    {
        B[i]=1ll*B[i]*D[i]%G;
    }
    NTT(B,NN,-1);
    for(int i=0;i<NN;i++)
    {
        C[i]=A[i]+B[i];
        C[i]-=C[i]>=mod?mod:0;
    }
    memset(f,0,sizeof(f)); 
    memset(D,0,sizeof(D));
    for(int i=1,tt;i<=m2;i++)
    {
        tt=1;
        for(int x=i;x;x/=10)
        {
            tt*=x%10;
        }
        f[tt]++;
    }
    tot=0;
    for(int i=1;i<=m2;i++)
    {
        if(f[i])
        {
            int tt=0;
            for(int j=4,q=i;j>=1;j--)
            {
                tt*=qq[j];
                while(q%pp[j]==0)
                {

                    q/=pp[j];
                    ++tt;
                }
            }
            D[tt]=f[i];
        }
    }
    for(int i=0;i<NN;i++)
    {
        C[i]+=D[i];
        C[i]-=C[i]>=mod?mod:0;
    }
    int tttt=0;
    for(int i=0;i<NN;i++)
    {
        if(C[i])
        {
            D[++tttt]=C[i];
        }
    }
    sort(D+1,D+1+tttt,cmp);
    Q.push(node(1,1));
    int cnt=0;
    node q;
    LL ans=0;
    while(Q.size())
    {
        q=Q.top();
        Q.pop();
        ans+=q.val%mod;
        ans-=ans>=mod?mod:0;
        ++cnt;
        if(cnt==k)
        {
            break;
        }
        if(q.x!=q.y)
        {
            ans+=q.val%mod;
            ans-=ans>=mod?mod:0;
            ++cnt;
            if(cnt==k)
            {
                break;
            }
            if(q.x+1<=tttt)
            {
                Q.push(node(q.x+1,q.y));
            }
        }
        if(q.x==1&&q.y+1<=tttt)
        {
            Q.push(node(q.x,q.y+1));
        }
    }
    printf("%I64d\n",ans);
    return ;
}
int main()
{
    //freopen("gold.in","r",stdin);
    //freopen("gold.out","w",stdout);
    cin>>n>>k;
    if(n<=2000000)
    {
        work1();
    }
    else
    {
        work2();
    }
    return 0;
}

代碼中的A,B,C,D數組是什麼鬼?t,tt,ttt,tttt記的是什麼?這些嘛,自己看自己想,反正就是亂整就對了.

至於爲什麼NTT不會出現結果超出998244353的情況,我只能說,打表發現,A,B,C,D這些東西並不是很大.

這份代碼在bzoj上光榮拿到了倒數第2(用戶:problem_set),想想吧,人家數位dp自然是比這個暴力+NTT要快很多的

發佈了31 篇原創文章 · 獲贊 6 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章