BZOJ 3992 [SDOI2015]序列統計

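思路:原根轉化 + NTT + 多項式快速冪(不是矩陣快速冪)。取 $m$ 的原根 $g$,把每個非零數 $v$ 寫成 $g^{\mathrm{ind}(v)}$,則「乘積 $\equiv x \pmod m$」等價於「指數和 $\equiv \mathrm{ind}(x) \pmod{m-1}$」。令 $F(z)=\sum_{s\in S,\ s\neq 0} z^{\mathrm{ind}(s)}$(因 $1 \le x < m$,含 0 的序列乘積為 0,不會計入答案),答案即 $F(z)^n \bmod (z^{m-1}-1)$ 中 $z^{\mathrm{ind}(x)}$ 的係數;多項式乘法用 NTT(模數 $1004535809$,原根 $3$),$n$ 次冪用快速冪。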

詳細推導從略,可參考:http://blog.csdn.net/ied98/article/details/46852805

#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 17005        // array capacity: must be >= m and >= the NTT length len (16384 when m <= 8192)
#define MOD 1004535809 // NTT-friendly prime 479 * 2^21 + 1; 3 is a primitive root of it
namespace runzhe2000
{
    typedef long long ll;

    // Problem input: sequence length n, prime modulus m, target residue x,
    // and the allowed values a[1..S] (each expected in [0, m-1]).
    int n, m, x, S, a[N];
    // ind[v] = discrete logarithm of v (1 <= v < m) to the base g.
    int ind[N];
    // g: primitive root of m; len: NTT length (smallest power of two >= 2m).
    int g, len;
    // Scratch for the distinct prime factors of p-1 found by get_g (1-indexed).
    int prime[N];
    // f: current power F^(2^k); ff: scratch transform buffer; h: accumulated result.
    // All are polynomials in z reduced modulo z^(m-1) - 1, coefficients mod MOD.
    int f[N], ff[N], h[N];

    /// Binary exponentiation: returns a^b mod `mod` (expects a >= 0, b >= 0).
    int fpow(int a, int b, int mod)
    {
        int r = 1;
        for (; b; b >>= 1)
        {
            if (b & 1) r = (ll)r * a % mod;
            a = (ll)a * a % mod;
        }
        return r;
    }

    /// Returns the smallest primitive root of the prime p (1 when p == 2).
    /// Side effect: stores the distinct prime factors of p-1 in prime[1..].
    int get_g(int p)
    {
        int tmp = p - 1, pcnt = 0;
        // Trial division only up to sqrt(tmp); any leftover factor > 1 is prime.
        for (int i = 2; (ll)i * i <= tmp; i++)
            if (tmp % i == 0)
            {
                prime[++pcnt] = i;
                while (tmp % i == 0) tmp /= i;
            }
        if (tmp > 1) prime[++pcnt] = tmp;
        // i generates Z_p^* iff i^((p-1)/q) != 1 for every prime q dividing p-1.
        for (int i = 1; ; i++)
        {
            bool ok = true;
            for (int j = 1; j <= pcnt && ok; j++)
                if (fpow(i, (p - 1) / prime[j], p) == 1) ok = false;
            if (ok) return i;
        }
    }

    /// In-place iterative NTT of f[0..n) over Z_MOD.
    /// g: primitive root of MOD; n: a power of two dividing MOD-1 (<= 2^21);
    /// re: nonzero selects the inverse transform, including the 1/n scaling.
    /// Entries stay in (-MOD, MOD) because C++ % keeps the dividend's sign;
    /// callers normalize to [0, MOD) when reading a final answer.
    void NTT(int *f, int g, int n, int re)
    {
        // Bit-reversal permutation: j walks the bit-reversed counterpart of i.
        for (int i = 0, j = 0; i < n; i++)
        {
            if (i < j) std::swap(f[i], f[j]);
            for (int l = n >> 1; (j ^= l) < l; l >>= 1);
        }
        for (int step = 2; step <= n; step <<= 1)
        {
            int half = step >> 1;
            int wn = fpow(g, (MOD - 1) / step, MOD); // primitive step-th root of unity
            if (re) wn = fpow(wn, MOD - 2, MOD);     // inverse transform uses its inverse
            for (int j = 0; j < n; j += step)
            {
                int w = 1;
                for (int k = 0; k < half; k++)
                {
                    int u = f[j + k];
                    int v = (ll)w * f[j + k + half] % MOD;
                    // |u|, |v| < MOD, so u +/- v stays below 2*MOD < INT_MAX.
                    f[j + k] = (u + v) % MOD;
                    f[j + k + half] = (u - v) % MOD;
                    w = (ll)w * wn % MOD;
                }
            }
        }
        if (re)
        {
            int inv = fpow(n, MOD - 2, MOD);
            for (int i = 0; i < n; i++) f[i] = (ll)f[i] * inv % MOD;
        }
    }

    /// Reduces a[0..len) cyclically modulo z^(m-1) - 1 and clears a[m-1..len).
    /// Exponents are discrete logarithms, which only matter modulo m-1.
    void fold(int *a)
    {
        for (int i = m - 1; i < len; i++)
        {
            a[i % (m - 1)] = (a[i % (m - 1)] + a[i]) % MOD; // both in (-MOD, MOD): no overflow
            a[i] = 0;
        }
    }

    /// a = a * b (mod z^(m-1) - 1), where bt already holds the forward NTT of b.
    /// bt is left untouched, so one transform can serve several products.
    void mul_transformed(int *a, const int *bt)
    {
        NTT(a, 3, len, 0);
        for (int i = 0; i < len; i++) a[i] = (ll)a[i] * bt[i] % MOD;
        NTT(a, 3, len, 1);
        fold(a);
    }

    /// a = a * b (mod z^(m-1) - 1). NOTE: b is overwritten with its forward NTT.
    void mul(int *a, int *b)
    {
        NTT(b, 3, len, 0);
        mul_transformed(a, b);
    }

    /// Counts length-n sequences over a[1..S] whose product is congruent to x
    /// (mod m), modulo MOD. The result is in [0, MOD).
    /// Reads the globals n, m, x, S and a; overwrites g, ind, len, f, ff and h.
    /// Assumes m is prime with m >= 2 and 2m <= N, and 1 <= x < m.
    int solve()
    {
        g = get_g(m);
        // ind[g^j mod m] = j for j in [0, m-2].
        for (int v = 1, j = 0; j < m - 1; j++)
        {
            ind[v] = j;
            v = (ll)v * g % m;
        }
        std::memset(f, 0, sizeof(f));
        std::memset(h, 0, sizeof(h));
        // F(z) = sum over s of z^ind[s]. 0 is skipped: it has no logarithm,
        // and any product containing it is 0, which never equals x.
        for (int i = 1; i <= S; i++)
            if (a[i]) f[ind[a[i]]]++;
        len = 1;
        while (len < m + m) len <<= 1;
        h[0] = 1; // h = 1, the empty product
        for (int e = n; e; e >>= 1)
        {
            // One forward transform of f, shared by h*f and f*f below.
            std::memcpy(ff, f, sizeof(int) * len);
            NTT(ff, 3, len, 0);
            if (e & 1) mul_transformed(h, ff);
            if (e > 1) // a squaring on the final step would be discarded
            {
                for (int i = 0; i < len; i++) f[i] = (ll)ff[i] * ff[i] % MOD;
                NTT(f, 3, len, 1);
                fold(f);
            }
        }
        return (h[ind[x]] + MOD) % MOD; // h[] is in (-MOD, MOD)
    }

    /// Reads "n m x S" followed by S values from stdin and prints the answer.
    /// Returns without output if the input is malformed or truncated.
    void main()
    {
        if (std::scanf("%d%d%d%d", &n, &m, &x, &S) != 4) return;
        for (int i = 1; i <= S; i++)
            if (std::scanf("%d", &a[i]) != 1) return;
        std::printf("%d\n", solve());
    }
}
// Program entry point: parsing, computation and output all live in runzhe2000::main.
int main()
{
    runzhe2000::main();
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章