Rikka with Array
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)Total Submission(s): 301 Accepted Submission(s): 119
Yuta has an array A of length n,and the ith element of A is equal to the sum of all digits of i in binary representation. For example,A[1]=1,A[3]=2,A[10]=2.
Now, Yuta wants to know the number of the pairs (i,j)(1≤i<j≤n) which satisfy A[i]>A[j].
It is too difficult for Rikka. Can you help her?
For each testcase, the first line contains a number n(n≤10300).
題目描述:給定一個數n(0 < n < 10^300),問有多少個數對(i , j),滿足1<=i < j <= n且A[i] > A[j],其中A[x]是x化成二進制之後中1的個數
思路:定義dp[len][sum][limit]表示當前枚舉到第len位,已經枚舉出的兩個數的數位中1的個數差爲i - j + 1000(加1000是因爲差可能爲負),枚舉到第len位時i和j的狀態爲limit時合法的數對,其中
limit == 0 表示i<j < n
limit == 1 表示i<j = n
limit == 2 表示i= j < n
limit == 3 表示i =j = n
之所以要定義這四種狀態,是因爲這四種狀態下在填i和j的第len-1位的時候的限制不同,轉移方程見代碼
收穫:1、對於兩個數的數位dp,其實道理和一個數時是一樣的,每次枚舉兩個數的這一位要放什麼數,只是貼上界更麻煩
2、大整數轉二進制並不需要高精度,從別人代碼裏學來了一個辦法替掉了自己原本代碼裏的高精度
#pragma warning(disable:4786)
#pragma comment(linker, "/STACK:102400000,102400000")
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<stack>
#include<queue>
#include<map>
#include<set>
#include<vector>
#include<cmath>
#include<string>
#include<sstream>
#include<bitset>
#define LL long long
#define FOR(i,f_start,f_end) for(int i=f_start;i<=f_end;++i)
#define mem(a,x) memset(a,x,sizeof(a))
#define lson l,m,x<<1
#define rson m+1,r,x<<1|1
using namespace std;
const int INF = 0x3f3f3f3f;
const int mod = 998244353;
const double PI = acos(-1.0);
const double eps=1e-6;
const int maxn = 1000;
const int base = 1000;
LL dp[4][maxn][maxn * 2] ;
char s[maxn];
int num[maxn] , bits[maxn];
LL dfs(int len , int sum , int limit , int zero)
{
if(sum > base && sum - base > len - 1) return 0;
if(len == 1) return (sum < base) && (!zero);
if(dp[limit][len][sum] != -1) return dp[limit][len][sum];
LL res = 0;
for(int i = 0 ; i < 2 ; i++){
for(int j = 0 ; j < 2 ; j++){
if(limit == 0){
res += dfs(len - 1 , sum + i - j , 0 , zero && j == 0);
res %= mod;
}
else if(limit == 1){
if(i > bits[len - 1]) continue;
res += dfs(len - 1 , sum + i - j , i == bits[len - 1] ? 1 : 0, zero && j == 0);
res %= mod;
}
else if(limit == 2){
if(i < j) continue;
res += dfs(len - 1 , sum + i - j , i == j ? 2 : 0, zero && j == 0);
res %= mod;
}
else{
if(i == bits[len - 1]){
if(j < i) res += dfs(len - 1 , sum + i - j , 1, zero && j == 0);
if(j == i) res += dfs(len - 1 , sum + i - j , 3, zero && j == 0);
}
else if(i < bits[len - 1]){
if(j < i) res += dfs(len - 1 , sum + i - j , 0, zero && j == 0);
if(j == i) res += dfs(len - 1 , sum + i - j , 2, zero && j == 0);
}
res %= mod;
}
}
}
dp[limit][len][sum] = res;
return res;
}
//convert函數用於大十進制數轉二進制數,其中len是大整數的位數,大整數已經存在num數組裏了
//如大整數爲14236,則len = 5 , num[1] = 6 , num[2] = 3 , num[3] = 2 , num[4] = 4,num[5] = 1,num[0]只用於判斷
//設大整數爲n,大整數位數爲len,時間複雜度是(logn * len)
int convert(int len)
{
int m = 1;
while(len){
for(int i = len ; i ; i--){
num[i - 1] += (num[i] & 1) * 10;
num[i] >>= 1;
} //這個操作一結束num中存的大整數就變成了原來的一半
bits[m++] = (num[0] != 0);
num[0] = 0;
if(!num[len])
--len;
}
return m;
}
LL solve(int len)
{
int cnt = convert(len);
LL ret = dfs(cnt , base , 3 , 1);
return ret;
}
int main()
{
int T;
scanf("%d" , &T);
mem(dp , -1);
while(T--){
mem(dp[1] , -1);
mem(dp[3] , -1);
scanf("%s" , s + 1);
int len = strlen(s + 1);
for(int i = 1 ; i<= len ; i++){
num[i] = s[len - i + 1] - '0';
}
LL ans = solve(len);
printf("%lld\n",ans);
}
return 0;
}