bzoj4818 [Sdoi2017]序列計數

傳送門

矩陣優化dp、容斥原理。
先寫出dp轉移方程:f[i&1][(j+k)%p][1]=f[(i&1)^1][j][0]+f[(i&1)^1][j][1] (這是我當時在考場上寫的20分暴力)
上面的方程對有無素數進行了分類,其實可以進行容斥,用所有方案數減去沒有素數的方案數。又因爲題目要求序列之和爲p的倍數,所以上面的方程對序列之和模p的餘數進行了存儲。而餘數讓我們想到了什麼呢?矩陣!矩陣可以對不同的餘數進行記錄,並且可以很好地處理轉移。題目中p的最大值爲100,所以我們可以開100*100的矩陣,矩陣中存儲餘數爲某一定值的時候的方案數。(這種方法是我自己yy的,所以我還是舉個例子比較好)

以p=4爲例:
這裏寫圖片描述
上面左邊的矩陣是轉移矩陣,中間的是初始矩陣,右邊的是結果矩陣,轉移矩陣中寫的數字是序列之和對p取模後的餘數,但事實上轉移矩陣中存儲的是序列之和爲該餘數時的方案數,初始矩陣和結果矩陣中的f[i]表示序列之和模p後餘數爲i的方案數
我們可以看出,通過將初始矩陣不斷地用轉移矩陣進行矩陣乘法,就可以得到將序列增長後可行的方案數。所以我們對轉移矩陣進行n-1次乘法,然後乘上初始矩陣就可以得出答案。這樣求出所有方案數和沒有素數的方案數就好了。

CODE:

#include<cstdio>
#include<cstring>
#define mod 20170408
#define N 20000005
int prime[1500000];
bool b[N];
int n,m,p,tot,ans;
struct Matrix
{
    int a[105][105];
    Matrix(){memset(a,0,sizeof(a));}
    inline Matrix operator *(const Matrix &x)const
    {
        Matrix ans;
        for(int i=0;i<p;i++)
          for(int j=0;j<p;j++)
            if(a[i][j]) for(int k=0;k<p;k++)
              ans.a[i][k]=(1ll*ans.a[i][k]+1ll*a[i][j]*x.a[j][k])%mod;
        return ans;
    }
}m1,m2,tmp1,tmp2;
inline void euler()
{
    for(int i=2;i<=m;i++)
    {
        if(!b[i]) prime[++tot]=i;
        for(int j=1;j<=tot&&i*prime[j]<=m;j++)
        {
            b[i*prime[j]]=1;
            if(i%prime[j]==0) break;
        }
    }
}
inline void Matrix_init()
{
    int num=m/p,rest=m%p;
    for(int i=0;i<p;i++)
      m1.a[0][i]=num;
    for(int i=p-1;rest;i--,rest--)
      m1.a[0][i]++;
    for(int i=1;i<p;i++)
    {
        m1.a[i][0]=m1.a[i-1][p-1];
        for(int j=1;j<p;j++)
          m1.a[i][j]=m1.a[i-1][j-1];
    }
    m2=m1;
    for(int i=1;i<=tot;i++)
    {
        int rest=prime[i]%p;
        int pos=p-rest;
        if(!rest) pos=0;
        for(int j=0;j<p;j++)
        {
            m2.a[j][pos]--;
            pos++;
            if(pos==p) pos=0;
        }
    }
    tmp1.a[0][0]=m1.a[0][0];
    tmp2.a[0][0]=m2.a[0][0];
    for(int i=1,pos=p-1;i<p;i++,pos--)
      tmp1.a[i][0]=m1.a[0][pos],
      tmp2.a[i][0]=m2.a[0][pos];
}
inline Matrix ksm(Matrix a,int b)
{
    Matrix ans=a;
    for(b--;b;b>>=1,a=a*a)
      if(b&1) ans=ans*a;
    return ans;
}
int main()
{
    scanf("%d%d%d",&n,&m,&p);
    euler(),Matrix_init();
    m1=ksm(m1,n-1),m2=ksm(m2,n-1);
    m1=m1*tmp1,m2=m2*tmp2;
    ans=m1.a[0][0]-m2.a[0][0];
    if(ans<0) ans+=mod;
    printf("%d",ans);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章