2020 CCPC Wannafly Winter Camp Day2 B.薩博的方程式(數位DP)

題意:

薩博有個方程式:

x1 xor x2  xor xn=k(xi[0,mi])x_1\ xor\ x_2\ \dots\ xor \ x_n=k\quad (x_i\in[0,m_i])
n50, k,mi<231n\le50,\ k,m_i<2^{31}
求有多少組x在滿足條件的情況下使得等式成立,答案對1e9+71e9+7取模。

題解:

以下爲邦邦老師的ppt,這一部分講的挺清晰的:
在這裏插入圖片描述
dp狀態的解釋也很清晰:
在這裏插入圖片描述
這樣的做法複雜度是O(Tn2log(m))O(T*n^2log(m))的。
我自己做的時候一開始不知道得到dp這個dp狀態之後如何得到這一位不全選1的解對答案的貢獻,後來想明白了:假設這一位有xx11,那麼F(x,j)F(x,j)對答案的貢獻是F(x,j)/2posF(x,j)/2^{pos},pos指的是最高位的位數。能做貢獻的前提是j的奇偶性和k的這一位是一樣的。
爲什麼呢。因爲F(x,j)F(x,j)相當於是總的方案數,然後我們因爲有一個k的限制,所以可以讓其中一個高位本可選1但是選了0的數字去和其他的湊,也就是說其他的定了它也就定了,不能亂動。這和它原本可以選擇2pos2^{pos}種方案相比,變成了只有一種選擇。
想清楚這一點之後我們發現,實際上不用記錄選擇了幾個1,而只要記錄選擇了奇數個1還是偶數個1,是否爲全選即可得到貢獻,這樣優化了一層記錄選擇1的個數的循環。dp(i,0/1,0/1)dp(i,0/1,0/1)表示考慮前i個1,用了偶/奇數個1,未全選/全選了1的方案數。
這樣複雜度可以優化爲O(Tnlog(m))O(T*nlog(m))。這樣數據範圍的nn就可以出到1e51e5啦~

未優化的代碼:

#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][55];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
    ll res = 1;
    while(b){
        if(b&1) res = res*a%mod;
        a = a*a%mod;
        b >>= 1;
    }return res;
}
ll sol(int pos){
    //cout<<"pos:"<<pos<<endl;
    if(pos < 0) return 1;
    ll res = 0;
    memset(dp, 0, sizeof dp);
    dp[0][0] = 1;
    int cur = 0;
    for(int i = 1; i <= n; ++i){
        if(m[i]>>pos&1){
            cur++;
            dp[cur][0] = dp[cur-1][0]*(1LL<<pos)%mod;
            for(int j = 1; j <= cur; ++j){
                dp[cur][j] = (dp[cur-1][j]*(1LL<<pos)%mod + dp[cur-1][j-1]*(m[i]-(1LL<<pos)+1)%mod)%mod;
            }
        }else{
            for(int j = 0; j <= cur; ++j) dp[cur][j] = dp[cur][j]*(m[i]+1)%mod;
        }
    }
    ll inv = qm(1LL<<pos, mod-2);
    for(int i = (k>>pos&1); i < cur; i+=2){
        res = (res + dp[cur][i]*inv)%mod;
    }
    //cout<<"res:"<<res<<endl;
    if((cur&1) == (k>>pos&1)){
        for(int i = 1; i <= n; ++i){
            if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
        }
        return (res + sol(pos-1))%mod;
    }else return res;
}
int main()
{
    while(scanf("%d%lld", &n, &k)!=EOF){
        for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
        ll ans = sol(31);
        ans = (ans + mod)%mod;
        cout<<ans<<endl;
    }
}

優化後的代碼

#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][2][2];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
    ll res = 1;
    while(b){
        if(b&1) res = res*a%mod;
        a = a*a%mod;
        b >>= 1;
    }return res;
}
ll sol(int pos){
    //cout<<"pos:"<<pos<<endl;
    if(pos < 0) return 1;
    ll res = 0;
    memset(dp, 0, sizeof dp);
    dp[0][0][1] = 1;
    int cur = 0;
    for(int i = 1; i <= n; ++i){
        if(m[i]>>pos&1){
            cur++;
            dp[cur][0][0] =( (m[i]-(1LL<<pos)+1)*dp[cur-1][1][0]%mod + (1LL<<pos)*(dp[cur-1][0][0]+dp[cur-1][0][1])%mod )%mod;
            dp[cur][0][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][1][1]%mod;
            dp[cur][1][0] = ( (m[i]-(1LL<<pos)+1)*dp[cur-1][0][0]%mod + (1LL<<pos)*(dp[cur-1][1][0] + dp[cur-1][1][1])%mod )%mod;
            dp[cur][1][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][0][1]%mod;
        }else{
            dp[cur][0][0] = (dp[cur][0][0]*(m[i]+1))%mod;
            dp[cur][1][0] = (dp[cur][1][0]*(m[i]+1))%mod;
            dp[cur][0][1] = (dp[cur][0][1]*(m[i]+1))%mod;
            dp[cur][1][1] = (dp[cur][1][1]*(m[i]+1))%mod;
        }
    }
    ll inv = qm(1LL<<pos, mod-2);
    res = dp[cur][k>>pos&1][0]*inv%mod;
    if((cur&1) == (k>>pos&1)){
        for(int i = 1; i <= n; ++i){
            if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
        }
        return (res + sol(pos-1))%mod;
    }else return res;

}
int main()
{
    while(scanf("%d%lld", &n, &k)!=EOF){
        for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
        ll ans = sol(31);
        ans = (ans + mod)%mod;
        cout<<ans<<endl;
    }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章