Hdu-5519 Kykneion asma(狀壓DP+容斥)

On the last day before the famous mathematician Swan's death, he left a problem to the world: Given integers nn and aiai for 0i40≤i≤4, calculate the number of nn-digit integers which have at most aiai-digit ii in its decimal representation (and have no 5,6,7,85,6,7,8 or 99). Leading zeros are not allowed in this problem.
InputThere is one integer T (1<T10)T (1<T≤10) in the beginning of input, which means that you need to process TT test cases. In each test case, there is one line containing six integers representing nn and a0a0 to a4a4, where 2n150002≤n≤15000 and 0ai300000≤ai≤30000. OutputFor each test case, you should print first the identifier of the test case and then the answer to the problem, module 109+7109+7. Sample Input
10
5 0 1 2 3 4
5 1 1 1 1 1
5 2 2 2 2 2
5 3 3 3 3 3
5 3 2 1 3 2
5 3 2 0 0 0
5 0 0 0 5 0
7000 41 2467 6334 2500 3169
7000 7724 3478 5358 2962 464
7000 5705 4145 7281 827 1961
Sample Output
Case #1: 535
Case #2: 96
Case #3: 1776
Case #4: 2416
Case #5: 1460
Case #6: 4
Case #7: 1
Case #8: 459640029
Case #9: 791187801

Case #10: 526649529

題意: 給你包括0的5個數字,又告訴你每個數字的最多使用次數,問你由這5個數字組成的不同n位數(無前導0)有多少個。

分析:這題可以直接上母函數+FFT,但是不能直接用NTT因爲模數不是費馬素數,然後在網上看到了還有這種容斥+狀壓DP的做法,容斥很好想,但是難點在怎麼用狀壓DP求給定數字集超出限制的方案數,這裏有一個小trick就是如果我們想讓某個數字i超出限制,那麼直接放入a[i]+1個數字就行了,f[i][mask][j]表示當前在第i位,當前 mask 這些數字超出了限制,且最終超出限制的數字爲j個的方案數,那麼有:

f[i][mask][j] = f[i-1][mask][j]*(5 - j + count(mask))+sigma(f[i-a[k]-1][mask xor (1<<k)][j]*C(i-1,a[k]))

加上f[i-1][mask][j]這部分很容易理解,相當於枚舉第i位的數字,右邊的部分相當於在此刻剛好超過a[k]這個限制時的方案數。

#include <bits/stdc++.h>
#define MOD 1000000007
using namespace std;
typedef long long ll;
int T,Time,n,cnt[32],a[6];
ll jc[15005],inv[15005],f[15005][32][6];
void exgcd(ll a,ll b,ll &g,ll &x,ll &y)
{
    if(!b) g=a,x=1,y=0;
    else
    {
        exgcd(b,a%b,g,y,x);
        y-=a/b*x;
    }
}
ll Inv(ll a,ll n)
{
    ll d,x,y;
    exgcd(a,n,d,x,y);
    return d == 1 ? (x+n)%n : -1;
}
int lowbit(int x)
{
    return x & -x;
}
ll c(int x,int y)
{
    return (jc[x]*inv[y] % MOD)*inv[x-y] % MOD;
}
void add(ll &x,ll y)
{
    y %= MOD;
    x = (x + y) % MOD;
}
ll got(int n)
{
    memset(f,0,sizeof(f));
    for(int i = 1;i <= 5;i++) f[0][0][i] = 1;
    for(int j = 1;j <= 5;j++)
     for(int i = 1;i <= n;i++)
      for(int mask = 0;mask < 32;mask++)
      if(cnt[mask] <= j)
      {
         add(f[i][mask][j],f[i-1][mask][j] * (5 - j + cnt[mask]));
         for(int k = 1;k <= 5;k++)
         if(((1<<(k-1)) & mask) && i >= a[k] + 1) add(f[i][mask][j],f[i - a[k] - 1][mask - (1<<(k-1))][j]*c(i-1,a[k]));
      }
    ll temp = 1;
    for(int i = 1;i <= n;i++) temp = temp*5 % MOD;
    for(int mask = 1;mask < 32;mask++)
     if(cnt[mask] & 1) temp = (temp - f[n][mask][cnt[mask]] + MOD) % MOD;
     else temp = (temp + f[n][mask][cnt[mask]]) % MOD;
    return temp;
}
int main()
{
    jc[0] = inv[0] = 1;
    for(int i = 1;i <= 15000;i++) jc[i] = jc[i-1] * i % MOD,inv[i] = Inv(jc[i],MOD);
    for(int i = 1;i < 32;i++) cnt[i] = cnt[i - lowbit(i)] + 1;
    cin.sync_with_stdio(false);
    cin>>T;
    while(T--)
    {
        cin>>n;
        for(int i = 1;i <= 5;i++) cin>>a[i];
        if(!a[1]) cout<<"Case #"<<++Time<<": "<<got(n)<<endl;
        else
        {
            int temp = got(n);
            a[1]--;
            temp = (temp - got(n-1) + MOD) % MOD;
            cout<<"Case #"<<++Time<<": "<<temp<<endl;
        }
    }
}



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章