计蒜客
Spare Tire
编辑代码
- 15.27%
- 1000ms
- 131072K
A sequence of integer \lbrace a_n \rbrace{an} can be expressed as:
\displaystyle a_n = \left\{ \begin{array}{lr} 0, & n=0\\ 2, & n=1\\ \frac{3a_{n-1}-a_{n-2}}{2}+n+1, & n>1 \end{array} \right.an=⎩⎨⎧0,2,23an−1−an−2+n+1,n=0n=1n>1
Now there are two integers nn and mm. I'm a pretty girl. I want to find all b_1,b_2,b_3\cdots b_pb1,b2,b3⋯bp that 1\leq b_i \leq n1≤bi≤n and b_ibiis relatively-prime with the integer mm. And then calculate:
\displaystyle \sum_{i=1}^{p}a_{b_i}i=1∑pabi
But I have no time to solve this problem because I am going to date my boyfriend soon. So can you help me?
Input
Input contains multiple test cases ( about 1500015000 ). Each case contains two integers nn and mm. 1\leq n,m \leq 10^81≤n,m≤108.
Output
For each test case, print the answer of my question(after mod 1,000,000,0071,000,000,007).
Hint
In the all integers from 11 to 44, 11 and 33 is relatively-prime with the integer 44. So the answer is a_1+a_3=14a1+a3=14.
样例输入复制
4 4
样例输出复制
14
题目来源
由通项公式的递推式得:an=n^2+n;
设S=1^2+2^2+.+n^2
(n+1)^3-n^3 = 3n^2+3n+1
n^3-(n-1)^3 = 3(n-1)^2+3(n-1)+1
...
..
...
2^3-1^3 = 3*1^2+3*1+1
把上面n个式子相加得:(n+1)^3-1 = 3* [1^2+2^2+...+n^2] +3*[1+2+.+n] +n
所以S= (1/3)*[(n+1)^3-1-n-(3/2)*n(n+1)] = (1/6)n(n+1)(2n+1)
n的前n项和显而易见为n*(n+1)/2;
根据容斥定理:通过状压枚举n范围内m的每个因子倍数的数量,显而易见n范围内m的因子和m不互素。我们可知m的因子y在n范围内的倍数数量为 cnt = n /y; 所以与x有关的与m不互素的和为 y*y*cnt*(cnt+1)*(2*cnt+1)+y*cnt*(cnt+1)。
由容斥的基本思想我们可以知道;我们减去单个因子的倍数,同时也会把两个因子的公倍数减去两边。所以我们还要加上两个因子公倍数在n范围内的倍数的和。
代码如下:
#include<bits/stdc++.h>
typedef long long LL;
using namespace std;
const int maxn = 1e8 + 10;
const LL mod = 1000000007;
int arr[maxn/10];
int p;
LL _pow(LL a,LL b)
{
LL res= 1;
while(b)
{
if(b&1) res=res*a%mod;
b=b>>1;
a=a*a%mod;
}
return res;
}
LL get(LL x, LL y) {
LL cnt = x / y;
return (y*y%mod*cnt%mod*(cnt+1)%mod*(2*cnt+1)%mod*_pow(6,mod-2)%mod + y*cnt%mod*(cnt+1)%mod*_pow(2,mod-2)%mod)%mod;
}
void getp(LL n) { //将要求预期互质的数因子分解
p = 0;
for(int i = 2; i * i <= n; i++) {
if(n % i == 0) {
arr[p++] = i;
while(n % i == 0)
n /= i;
}
}
if(n > 1) arr[p++] = n;
//printf("(%d)", p);
}
int main() {
LL n,m,sum;
while(scanf("%lld %lld", &n,&m) != EOF) {
LL ans = 0;
if( n == 1) {
sum=2;
printf("%lld\n", ((sum - ans)+mod)%mod);
continue;
}
getp(m);
sum =(n*(n+1)%mod*(2*n+1)%mod*_pow(6,mod-2)%mod + n*(n+1)%mod*_pow(2,mod-2)%mod)%mod;
for(int i = 1; i < (1 << p); i++) { //状压
LL res = 0, cnt = 1;
for(int j = 0; j < p; j++) {
if(i & (1 << j)) {
cnt *= arr[j];
res++;
}
}
if(res & 1){//容斥
ans += get(n, cnt);
ans%=mod;
}
else{
ans -= get(n, cnt);
ans=(ans+mod)%mod;
}
}
printf("%lld\n", ((sum - ans)+mod)%mod);
}
return 0;
}