題意:
薩博有個方程式:
求有多少組x在滿足條件的情況下使得等式成立,答案對取模。
題解:
以下爲邦邦老師的ppt,這一部分講的挺清晰的:
dp狀態的解釋也很清晰:
這樣的做法複雜度是的。
我自己做的時候一開始不知道得到dp這個dp狀態之後如何得到這一位不全選1的解對答案的貢獻,後來想明白了:假設這一位有個,那麼對答案的貢獻是,pos指的是最高位的位數。能做貢獻的前提是j的奇偶性和k的這一位是一樣的。
爲什麼呢。因爲相當於是總的方案數,然後我們因爲有一個k的限制,所以可以讓其中一個高位本可選1但是選了0的數字去和其他的湊,也就是說其他的定了它也就定了,不能亂動。這和它原本可以選擇種方案相比,變成了只有一種選擇。
想清楚這一點之後我們發現,實際上不用記錄選擇了幾個1,而只要記錄選擇了奇數個1還是偶數個1,是否爲全選即可得到貢獻,這樣優化了一層記錄選擇1的個數的循環。表示考慮前i個1,用了偶/奇數個1,未全選/全選了1的方案數。
這樣複雜度可以優化爲。這樣數據範圍的就可以出到啦~
未優化的代碼:
#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][55];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
ll res = 1;
while(b){
if(b&1) res = res*a%mod;
a = a*a%mod;
b >>= 1;
}return res;
}
ll sol(int pos){
//cout<<"pos:"<<pos<<endl;
if(pos < 0) return 1;
ll res = 0;
memset(dp, 0, sizeof dp);
dp[0][0] = 1;
int cur = 0;
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1){
cur++;
dp[cur][0] = dp[cur-1][0]*(1LL<<pos)%mod;
for(int j = 1; j <= cur; ++j){
dp[cur][j] = (dp[cur-1][j]*(1LL<<pos)%mod + dp[cur-1][j-1]*(m[i]-(1LL<<pos)+1)%mod)%mod;
}
}else{
for(int j = 0; j <= cur; ++j) dp[cur][j] = dp[cur][j]*(m[i]+1)%mod;
}
}
ll inv = qm(1LL<<pos, mod-2);
for(int i = (k>>pos&1); i < cur; i+=2){
res = (res + dp[cur][i]*inv)%mod;
}
//cout<<"res:"<<res<<endl;
if((cur&1) == (k>>pos&1)){
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
}
return (res + sol(pos-1))%mod;
}else return res;
}
int main()
{
while(scanf("%d%lld", &n, &k)!=EOF){
for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
ll ans = sol(31);
ans = (ans + mod)%mod;
cout<<ans<<endl;
}
}
優化後的代碼
#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid+1, r
using namespace std;
const ll mod = 1e9 + 7;
ll dp[55][2][2];
int n;
ll k;
ll m[55];
ll qm(ll a, ll b){
ll res = 1;
while(b){
if(b&1) res = res*a%mod;
a = a*a%mod;
b >>= 1;
}return res;
}
ll sol(int pos){
//cout<<"pos:"<<pos<<endl;
if(pos < 0) return 1;
ll res = 0;
memset(dp, 0, sizeof dp);
dp[0][0][1] = 1;
int cur = 0;
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1){
cur++;
dp[cur][0][0] =( (m[i]-(1LL<<pos)+1)*dp[cur-1][1][0]%mod + (1LL<<pos)*(dp[cur-1][0][0]+dp[cur-1][0][1])%mod )%mod;
dp[cur][0][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][1][1]%mod;
dp[cur][1][0] = ( (m[i]-(1LL<<pos)+1)*dp[cur-1][0][0]%mod + (1LL<<pos)*(dp[cur-1][1][0] + dp[cur-1][1][1])%mod )%mod;
dp[cur][1][1] = (m[i]-(1LL<<pos)+1)*dp[cur-1][0][1]%mod;
}else{
dp[cur][0][0] = (dp[cur][0][0]*(m[i]+1))%mod;
dp[cur][1][0] = (dp[cur][1][0]*(m[i]+1))%mod;
dp[cur][0][1] = (dp[cur][0][1]*(m[i]+1))%mod;
dp[cur][1][1] = (dp[cur][1][1]*(m[i]+1))%mod;
}
}
ll inv = qm(1LL<<pos, mod-2);
res = dp[cur][k>>pos&1][0]*inv%mod;
if((cur&1) == (k>>pos&1)){
for(int i = 1; i <= n; ++i){
if(m[i]>>pos&1) m[i] ^= (1LL<<pos);
}
return (res + sol(pos-1))%mod;
}else return res;
}
int main()
{
while(scanf("%d%lld", &n, &k)!=EOF){
for(int i = 1; i <= n; ++i) scanf("%lld", &m[i]);
ll ans = sol(31);
ans = (ans + mod)%mod;
cout<<ans<<endl;
}
}