codeforces 895C

(狀壓dp)
題意:給定一個集合,裏面包括n(1n105) 個數字a[i](1a[i]70) ,求出這個集合中有多少子集,使得子集內部所有數字的乘積爲平方數。

思路:觀察到a[i] 的範圍比n 要小很多,於是我們可以用一個小數組保存:大小爲i 的數字出現過幾次,然後枚舉i(1i70) 。計數問題很有可能是dp,枚舉i 時最關鍵是要想到怎麼保存i 之前數字的乘積有哪些?這裏由於平方數的在質因子分解後,每個質因子的冪一定是偶數,於是想到只保存質因子取二進制最低位後的乘積,然而乘積太大也沒法保存。於是觀察到70以內的質因子只有不到20個,就可以考慮用狀態壓縮保存。然後dp方程就很好想了。(最後記得優化一下內存空間)
cnt[i]=0 時直接將i1 轉移到i 即可。
cnt[i]0 時:

{dp[i][j]=dp[i][j]+dp[i1][j]2cnt[i]1dp[i][jmask[i]]=dp[i][jmask[i]]+dp[i1][j]2cnt[i]1

代碼:
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define LL long long

using namespace std;
const int maxn = 100050;
const LL mod = 1e9 + 7;

LL dp[2][1<<19], pw[maxn];
int a[maxn], cnt[80], bitmask[80];
vector<int> prim;

void init() {
    // get 2^n
    pw[0] = 1;
    for(int i=1; i<maxn; i++)
        pw[i] = 2LL * pw[i-1] % mod;
    prim.clear();
    memset(cnt, 0, sizeof(cnt));
    memset(bitmask, 0, sizeof(bitmask));
    memset(dp, 0, sizeof(dp));
    // get prime number
    for(int i=2; i<71; i++) {
        bool flag = 1;
        for(int j=2; j*j<=i; j++)
            if(i%j == 0) flag = 0;
        if(flag) prim.push_back(i);
    }
    // get bitmask for each i
    for(int i=2; i<71; i++) {
        int t = i;
        for(int j=0; j<(int)prim.size(); j++) {
            int num = 0, div = prim[j];
            while(t%div == 0) {
                t /= div;
                num ++;
            }
            if(num&1)
                bitmask[i] += 1<<j;
        }
    }
}

int main() {
    //freopen("test.txt","r",stdin);
    init();
    int n;
    scanf("%d",&n);
    for(int i=0; i<n; i++) {
        scanf("%d",&a[i]);
        cnt[a[i]] ++;
    }
    // solve
    int sz = (int)prim.size();
    dp[0][0] = 1;
    for(int i=1; i<71; i++) {
        int mask = bitmask[i], cur = i&1, last = !cur;
        for(int j=0; j<(1<<sz); j++) {
            if(cnt[i] == 0)
                dp[cur][j] = dp[last][j];
            else {
                dp[cur][j] = (dp[cur][j] + dp[last][j]*pw[cnt[i]-1]%mod) % mod;
                dp[cur][j^mask] = (dp[cur][j^mask] + dp[last][j]*pw[cnt[i]-1]%mod) % mod;
            }
        }
        for(int j=0; j<(1<<sz); j++)
            dp[last][j] = 0;
        //printf("%d : %I64d\n",i,dp[i][0]);
    }
    printf("%I64d\n",(dp[0][0]-1+mod)%mod);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章