UOJ#86:mx的組合數 (Lucas定理+原根+NTT+高精度)

題目傳送門:http://uoj.ac/problem/86


題目分析:高精度寫死人系列,我寫了一個晚上才寫完QAQ。

一開始拿到這題沒什麼頭緒,然後從部分分開始想。上數學課的時候忽然間發現40分的部分分就是個暴力枚舉+Lucas定理。根據:

Cmn=CmpnpCmmodpnmodp

直接枚舉m[L,R] ,然後遞歸到m<p,n<p 時退出即可。

然後我們發現這個遞歸大概展開logp(R) 層,而且這很像個數位DP。於是我們不妨對原問題差分,用函數Solve(k,n) 求出當m[0,k] 時,Cmna(modp)(0<=a<p) 的答案。很明顯可以先調用Solve(kp1,np) ,將其答案與Cxnmodp(0<=x<p) 構成的數組進行合併,然後單獨處理C(kp,np)Cxnmodp(0<=x<=kmodp) 的部分。我們發現前者中兩個數組的合併是下標乘積的形式,而p又是個質數,所以可以轉化爲原根的冪然後做NTT。由於不是很好處理模p等於0的情況,可以先算出模p不爲0的情況,最後再用R-L+1減去。最後的複雜度是plog(p)logp(R)

總之這是個十分套路的題目,然而要寫高精度所以比較煩。而且我一開始還寫錯了幾個地方:一是[0,R]中有R+1個數,我以爲是R個數QAQ;二是遞歸的退出條件不一定是m<p,n<p ,還有可能在別的一些地方……;三是我把對ans[0]的減法放在了if (l.num[0])裏面,這樣l=0就炸了。

隨手寫一發,我的code居然排到了rk2?!


CODE:

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;

const int maxn=1000000;
const int maxl=50;
const long long M=998244353;
const long long g=3;
typedef long long LL;

LL A[maxn];
LL B[maxn];

int Rev[maxn];
int N,Lg;

struct Big_Int
{
    LL num[maxl];
    void Down() { while ( !num[ num[0] ] && num[0] ) num[0]--; }
} n,n1,l,r;
char s[maxl];

#define P pair<Big_Int,long long>
#define MP(x,y) make_pair(x,y)

LL fac[maxn];
LL nfac[maxn];

int id[maxn];
int rid[maxn];

LL ans[maxn];
LL p,pg;

LL Pow(LL x,LL y,LL z)
{
    if (!y) return 1LL;
    LL temp=Pow(x,y>>1,z);
    temp=temp*temp%z;
    if (y&1) temp=temp*x%z;
    return temp;
}

LL Get(LL x)
{
    x--;
    LL mz=(long long)floor( sqrt( (double)x )+0.5 );
    for (LL y=2; y<=x; y++)
    {
        bool sol=true;
        for (LL z=2; z<=mz; z++)
            if (x%z==0)
            {
                if ( Pow(y,z,p)==1LL ) sol=false;
                if ( Pow(y,x/z,p)==1LL ) sol=false;
                if (!sol) break;
            }
        if (sol) return y;
    }
}

void Read(Big_Int &x)
{
    scanf("%s",s);
    int len=strlen(s);
    x.num[0]=len;
    for (int i=0; i<len; i++) x.num[len-i]=s[i]-'0';
    x.Down();
}

P Div(Big_Int x,LL y)
{
    for (int i=x.num[0]; i>=2; i--)
    {
        x.num[i-1]+=(x.num[i]%y*10LL);
        x.num[i]/=y;
    }
    LL z=x.num[1]%y;
    x.num[1]/=y;
    x.Down();
    return ( MP(x,z) );
}

LL Change(Big_Int x)
{
    LL y=0;
    for (int i=x.num[0]; i>=1; i--)
    {
        y=y*10+x.num[i];
        if (y>=p) return -1LL;
    }
    return y;
}

LL CI(int x,int y)
{
    if (y>x) return 0;
    LL v=fac[x];
    v=v*nfac[y]%p;
    v=v*nfac[x-y]%p;
    return v;
}

void Dec(Big_Int &x)
{
    x.num[1]--;
    int y=1;
    while (x.num[y]<0) x.num[y]+=10,x.num[++y]--;
    x.Down();
}

void DFT(LL *a,int f)
{
    for (int i=0; i<N; i++)
        if (i<Rev[i]) swap(a[i],a[ Rev[i] ]);

    for (int len=2; len<=N; len<<=1)
    {
        int mid=(len>>1);
        LL e=Pow(g,(M-1)/len,M);
        if (f==-1) e=Pow(e,M-2,M);

        for (LL *p=a; p!=a+N; p+=len)
        {
            LL wn=1;
            for (int i=0; i<mid; i++)
            {
                LL temp=wn*p[mid+i]%M;
                p[mid+i]=(p[i]-temp+M)%M;
                p[i]=(p[i]+temp)%M;
                wn=wn*e%M;
            }
        }
    }
}

void NTT()
{
    DFT(A,1);
    DFT(B,1);
    for (int i=0; i<N; i++) A[i]=A[i]*B[i]%M;
    DFT(A,-1);

    LL inv=Pow(N,M-2,M);
    for (int i=0; i<N; i++) A[i]=A[i]*inv%M;
}

LL C(Big_Int k,Big_Int m)
{
    LL ck=Change(k);
    LL cm=Change(m);
    if ( ck>=0 && cm>=0 ) return CI(ck,cm);

    P x=Div(k,p);
    P y=Div(m,p);
    LL v=C(x.first,y.first);
    v=v*CI(x.second,y.second)%p;
    return v;
}

void Solve(Big_Int k,Big_Int m)
{
    LL ck=Change(k);
    LL cm=Change(m);
    if ( ck>=0 && cm>=0 )
    {
        for (int i=0; i<N; i++) A[i]=0;
        //this is not the only exit way,don't clear A[] here!!!
        for (int i=cm; i<=ck; i++) A[ id[ CI(i,cm) ] ]++;
        return;
    }

    P x=Div(k,p);
    P y=Div(m,p);
    if (x.first.num[0])
    {
        Big_Int z=x.first;
        Dec(z);
        Solve(z,y.first);
        for (int i=0; i<N; i++) B[i]=0;
        for (int i=y.second; i<p; i++) B[ id[ CI(i,y.second) ] ]++;
        NTT();
        for (int i=p-1; i<N; i++) A[i%(p-1)]=(A[i%(p-1)]+A[i])%M,A[i]=0;
    }

    LL v=C(x.first,y.first);
    if (v) for (int i=y.second; i<=x.second; i++)
    {
        LL &q=A[ id[ CI(i,y.second)*v%p ] ];
        q=(q+1LL)%M;
    }
}

int main()
{
    //freopen("86.in","r",stdin);
    //freopen("86.out","w",stdout);

    scanf("%lld",&p);

    pg=Get(p);
    LL v=1;
    for (int i=0; i<p-1; i++)
    {
        id[v]=i;
        rid[i]=v;
        v=v*pg%p;
    }

    fac[0]=1;
    for (LL i=1; i<p; i++) fac[i]=fac[i-1]*i%p;
    for (int i=0; i<p; i++) nfac[i]=Pow(fac[i],p-2LL,p);

    N=1,Lg=0;
    while (N<=2*p+2) N<<=1,Lg++;
    for (int i=0; i<N; i++)
        for (int j=0; j<Lg; j++)
            if (i&(1<<j)) Rev[i]|=(1<<(Lg-j-1));

    Read(n);
    n1=n;
    Read(l);
    Read(r);

    P x=Div(r,M);
    ans[0]=x.second;
    ans[0]=(ans[0]+1LL)%M; //[0,r] have r+1 numbers!!!
    Solve(r,n);
    for (int i=0; i<p-1; i++) ans[ rid[i] ]=A[i];

    if (l.num[0])
    {
        Dec(l);
        x=Div(l,M);
        ans[0]=(ans[0]-x.second+M)%M;
        ans[0]=(ans[0]-1LL+M)%M;
        for (int i=0; i<N; i++) A[i]=0; //clear A[] here!!!
        Solve(l,n1);

        for (int i=0; i<p-1; i++)
        {
            int v=rid[i];
            ans[v]=(ans[v]-A[i]+M)%M;
        }
    }

    for (int i=1; i<p; i++) ans[0]=(ans[0]-ans[i]+M)%M; //write outside the "if"!!!
    for (int i=0; i<p; i++) printf("%lld\n",ans[i]);

    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章