Rikka with Array
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)Total Submission(s): 301 Accepted Submission(s): 119
Yuta has an array A of length n,and the ith element of A is equal to the sum of all digits of i in binary representation. For example,A[1]=1,A[3]=2,A[10]=2.
Now, Yuta wants to know the number of the pairs (i,j)(1≤i<j≤n) which satisfy A[i]>A[j].
It is too difficult for Rikka. Can you help her?
For each testcase, the first line contains a number n(n≤10300).
题目描述:给定一个数n(0 < n < 10^300),问有多少个数对(i , j),满足1<=i < j <= n且A[i] > A[j],其中A[x]是x化成二进制之后中1的个数
思路:定义dp[len][sum][limit]表示当前枚举到第len位,已经枚举出的两个数的数位中1的个数差为i - j + 1000(加1000是因为差可能为负),枚举到第len位时i和j的状态为limit时合法的数对,其中
limit == 0 表示i<j < n
limit == 1 表示i<j = n
limit == 2 表示i= j < n
limit == 3 表示i =j = n
之所以要定义这四种状态,是因为这四种状态下在填i和j的第len-1位的时候的限制不同,转移方程见代码
收获:1、对于两个数的数位dp,其实道理和一个数时是一样的,每次枚举两个数的这一位要放什么数,只是贴上界更麻烦
2、大整数转二进制并不需要高精度,从别人代码里学来了一个办法替掉了自己原本代码里的高精度
#pragma warning(disable:4786)
#pragma comment(linker, "/STACK:102400000,102400000")
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<stack>
#include<queue>
#include<map>
#include<set>
#include<vector>
#include<cmath>
#include<string>
#include<sstream>
#include<bitset>
#define LL long long
#define FOR(i,f_start,f_end) for(int i=f_start;i<=f_end;++i)
#define mem(a,x) memset(a,x,sizeof(a))
#define lson l,m,x<<1
#define rson m+1,r,x<<1|1
using namespace std;
const int INF = 0x3f3f3f3f;
const int mod = 998244353;
const double PI = acos(-1.0);
const double eps=1e-6;
const int maxn = 1000;
const int base = 1000;
LL dp[4][maxn][maxn * 2] ;
char s[maxn];
int num[maxn] , bits[maxn];
LL dfs(int len , int sum , int limit , int zero)
{
if(sum > base && sum - base > len - 1) return 0;
if(len == 1) return (sum < base) && (!zero);
if(dp[limit][len][sum] != -1) return dp[limit][len][sum];
LL res = 0;
for(int i = 0 ; i < 2 ; i++){
for(int j = 0 ; j < 2 ; j++){
if(limit == 0){
res += dfs(len - 1 , sum + i - j , 0 , zero && j == 0);
res %= mod;
}
else if(limit == 1){
if(i > bits[len - 1]) continue;
res += dfs(len - 1 , sum + i - j , i == bits[len - 1] ? 1 : 0, zero && j == 0);
res %= mod;
}
else if(limit == 2){
if(i < j) continue;
res += dfs(len - 1 , sum + i - j , i == j ? 2 : 0, zero && j == 0);
res %= mod;
}
else{
if(i == bits[len - 1]){
if(j < i) res += dfs(len - 1 , sum + i - j , 1, zero && j == 0);
if(j == i) res += dfs(len - 1 , sum + i - j , 3, zero && j == 0);
}
else if(i < bits[len - 1]){
if(j < i) res += dfs(len - 1 , sum + i - j , 0, zero && j == 0);
if(j == i) res += dfs(len - 1 , sum + i - j , 2, zero && j == 0);
}
res %= mod;
}
}
}
dp[limit][len][sum] = res;
return res;
}
//convert函数用于大十进制数转二进制数,其中len是大整数的位数,大整数已经存在num数组里了
//如大整数为14236,则len = 5 , num[1] = 6 , num[2] = 3 , num[3] = 2 , num[4] = 4,num[5] = 1,num[0]只用于判断
//设大整数为n,大整数位数为len,时间复杂度是(logn * len)
int convert(int len)
{
int m = 1;
while(len){
for(int i = len ; i ; i--){
num[i - 1] += (num[i] & 1) * 10;
num[i] >>= 1;
} //这个操作一结束num中存的大整数就变成了原来的一半
bits[m++] = (num[0] != 0);
num[0] = 0;
if(!num[len])
--len;
}
return m;
}
LL solve(int len)
{
int cnt = convert(len);
LL ret = dfs(cnt , base , 3 , 1);
return ret;
}
int main()
{
int T;
scanf("%d" , &T);
mem(dp , -1);
while(T--){
mem(dp[1] , -1);
mem(dp[3] , -1);
scanf("%s" , s + 1);
int len = strlen(s + 1);
for(int i = 1 ; i<= len ; i++){
num[i] = s[len - i + 1] - '0';
}
LL ans = solve(len);
printf("%lld\n",ans);
}
return 0;
}