BZOJ 4338 糖果(擴展Lucas定理+CRT)

題目鏈接:BZOJ 4338

題目大意:用數字1~k填一個n*m的表格,每種數字可用任意次,要求每行數字1~m列單調不減,任意兩行不完全相同,求方案數對P取模的值。

題解:擴展Lucas+CRT模板題,板子還不是太熟悉,貼到這裏方便複習,有空回來加點註釋。最後答案的式子比較容易得到,是 ACm+k1mn  mod P

code

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 100005
using namespace std;
typedef long long ll;
inline int read()
{
    char c=getchar(); int num=0,f=1;
    while (c<'0'||c>'9') { if (c=='-') f=-1; c=getchar(); }
    while (c<='9'&&c>='0') { num=num*10+c-'0'; c=getchar(); }
    return num*f;
}
int n,m,k,mod,d[15],D[15],mo,phi[15],ans[15],num,jc[N],inv[N];
inline int add(int x,int y,int p) { x+=y; if (x>=p) x-=p; return x; }
inline int ksm(int a,int b,int p)
{
    int ret=1;
    for (;b;b>>=1,a=1ll*a*a%p)
     if (b&1) ret=1ll*ret*a%p;
    return ret;
}
struct newnum{
    int val,tmp;
    newnum (int _val,int _tmp) { val=_val; tmp=_tmp; }
    newnum operator * (const newnum U) const
    {
        return newnum(1ll*val*U.val%D[mo],tmp+U.tmp);
    }
    newnum operator / (const newnum U) const
    {
        return newnum(1ll*val*ksm(U.val,phi[mo]-1,D[mo])%D[mo],tmp-U.tmp);
    }
};
inline void dvd(int p)
{
    for (int i=2;1ll*i*i<=p;i++)
     if (p%i==0)
     {
         d[++num]=i,D[num]=1;
         while (p%i==0) p/=i,D[num]*=i;
         phi[num]=D[num]/d[num]*(d[num]-1); 
         if (p==1) break;
     } 
    if (p!=1) d[++num]=D[num]=p,phi[num]=p-1;
}
inline newnum getjc(int x)
{
    newnum now=newnum(1,0);
    if (x>=d[mo]) now=now*getjc(x/d[mo]),now.tmp+=x/d[mo];
    if (x>=D[mo]) now=now*newnum(ksm(jc[D[mo]-1],x/D[mo],D[mo]),0);
    now=now*newnum(jc[x%D[mo]],0); 
    return now;
}
inline newnum getC(int n,int m)
{
    return getjc(n)/getjc(m)/getjc(n-m);
}
inline int getP(int n,int m,int p)
{
    int now=n;
    for (int i=2;i<=m;i++) now=add(now,p-1,p),n=1ll*n*now%p;
    return n;
}
void solve(int x)
{
    mo=x,jc[0]=1; int p=D[x];
    if (d[mo]==p&&p>m)
    {
        inv[1]=1;
        for(int i=2;i<=m;i++) inv[i]=p-1ll*inv[p%i]*(p/i)%p; 
        ans[mo]=1; int now=k+m-1;
        for (int i=1;i<=m;i++)
         ans[mo]=1ll*ans[mo]*now%p*inv[i]%p,now=add(now,p-1,p);
        ans[mo]=getP(ans[mo],n,p);
    }
    else
    {
        for (int i=1;i<p;i++)
        {
            jc[i]=jc[i-1];
            if (i%d[mo]) jc[i]=1ll*jc[i]*i%p;
        }
        newnum now=getC(k+m-1,m);
        ans[mo]=1ll*now.val*ksm(d[mo],now.tmp,p)%p;
        ans[mo]=getP(ans[mo],n,p);
    }
}
inline int CRT()
{
    int ret=0;
    for(int i=1;i<=num;i++)
    {
        mo=i;
        ret=add(ret,1ll*(mod/D[i])*ksm(mod/D[i],phi[i]-1,D[i])%mod*ans[i]%mod,mod);
    }
    return ret;
} 
int main()
{
    n=read(); m=read(); k=read(); mod=read();
    dvd(mod);
    for (int i=1;i<=num;i++) solve(i);
    printf("%d",CRT());
    return 0;
}

另一份簡潔一些的模板(BZOJ_3129_方程)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long LL;
inline int read()
{
    char c=getchar(); int num=0,f=1;
    while (c<'0'||c>'9') { if (c=='-') f=-1; c=getchar(); }
    while (c<='9'&&c>='0') { num=num*10+c-'0'; c=getchar(); }
    return num*f;
}
int T,P,bin[15],n,n1,n2,m,a[15],p[7],powp[7],jc[10205],ans[7],anss;
int ksm(int a,int b,int p)
{
    int ret=1;
    for (;b;b>>=1,a=1ll*a*a%p)
     if (b&1) ret=1ll*ret*a%p;
    return ret;
}
int exgcd(int a,int b,int &x,int &y)
{
    if (!b) x=1,y=0;
     else exgcd(b,a%b,y,x),y-=(a/b)*x;
}
int inv(int a,int b)
{
    if (!a) return 0;
    int x=0,y=0; exgcd(a,b,x,y);
    x=(x%b+b)%b;
    return x;
}
int fac(int n,int pi,int pk)
{
    if (n==0) return 1; int ret=1;
    if (n/pk) ret=ksm(jc[pk],n/pk,pk);
    ret=1ll*ret*jc[n%pk]%pk;
    return 1ll*ret*fac(n/pi,pi,pk)%pk;
}
int comb(int n,int m,int pi,int pk)
{
    if (n<m) return 0; int k=0;
    int a=fac(n,pi,pk),b=fac(m,pi,pk),c=fac(n-m,pi,pk);
    for (int i=n;i;i/=pi) k+=i/pi;
    for (int i=m;i;i/=pi) k-=i/pi;
    for (int i=n-m;i;i/=pi) k-=i/pi;
    int ret=1ll*a*inv(b,pk)%pk*inv(c,pk)%pk*ksm(pi,k,pk)%pk;
    return ret;
}
inline int add(int a,int b) { a+=b; if (a>=P) a-=P; return a; }
inline int sub(int a,int b) { a-=b; if (a<0) a+=P; return a; }
int main()
{
    bin[0]=1; for (int i=1;i<=10;i++) bin[i]=bin[i-1]<<1;
    T=read(); P=read(); int tmp=P;
    for (int i=2;i*i<=tmp;i++)
     if (tmp%i==0)
     {
         p[++p[0]]=i; powp[p[0]]=1;
         while (tmp%i==0) powp[p[0]]*=i,tmp/=i;
     }
    if (tmp>1) p[++p[0]]=powp[p[0]]=tmp;
    while (T--)
    {
        n=read(); n1=read(); n2=read(); m=read()-n; anss=0;
        for (int i=1;i<=n1;i++) a[i]=read();
        for (int i=1;i<=n2;i++) m-=read()-1;
        if (m<0) { printf("0\n"); continue; }
        for (int cur=1;cur<=p[0];cur++)
        {
            jc[0]=1; ans[cur]=0;
            for (int j=1;j<=powp[cur];j++)
            {
                jc[j]=jc[j-1];
                if (j%p[cur]) jc[j]=1ll*jc[j]*j%powp[cur];
            }
            for (int sta=0;sta<bin[n1];sta++)
            {
                int cnt=0,ret; tmp=m;
                for (int i=1;i<=n1;i++)
                 if (sta&bin[i-1]) cnt++,tmp-=a[i];
                if (tmp<0) continue;
                ret=comb(tmp+n-1,tmp,p[cur],powp[cur]);
                if (cnt&1) ans[cur]=sub(ans[cur],ret);
                 else ans[cur]=add(ans[cur],ret);
            }
        }
        for (int cur=1;cur<=p[0];cur++)
        {
            int ret=1ll*ans[cur]*(P/powp[cur])%P*inv(P/powp[cur],powp[cur])%P;
            anss=add(anss,ret);
        }
        printf("%d\n",anss);
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章