Description:
題解:
首先暴力模擬這樣的一個插入過程,不難發現每次就是找到v∈[x,y]的出現時間的最小的,然後走過去,區間變爲[x,v-1]或[v+1,y],一直到葉子節點。
先設d=gcd(b,m)
顯然的結論是,2*m/d輪以後,每次插入只會使那個點的深度加一。
之所以不是m/d輪,是因爲比如第x輪加了一個東西,剩下的可能加到它的子樹中,第x+m/d輪時,就應是它第x輪的點的右子樹的最左節點的深度+1。
如果我們能快速知道第x(x<=m/d)輪的點深度,
假設第x輪的權值是v,我們只需要討論一下v和v+d的加入時間即可計算m/d輪以後的答案。
考慮第x(x<=m/d)輪的點深度,開頭那個模擬的過程,找到[x,y]裏的最小的,縮小範圍,繼續找,記錄經過的點,實際上可以分成log段等差數列。
對於區間[x,y],先找到了v1,再找到了v2,如果v2+(v2-v1)合法,那麼下次找到的一定是這個。
這個利用反證法易得(也很顯然),這樣就類似於gcd的過程,所以是log段的。
那麼現在的目標就是找到最小的x使
a是常數,相當於平移,可以去掉,也就是
一個暴力的做法:
這是一個經典的問題,可以用標準類歐解決:
這樣總複雜度是的,TLE了。
實際上找到最小的x使可以直接類歐實現做到一個log。
設表示最小的使
總複雜度
關於實現的一點小細節,其實不用判斷v和v+d到底是誰早,可以求v的左偏父親和v+d的右偏父親的個數和即可。
Code:
#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
#define ff(i, x, y) for(int i = x, B = y; i < B; i ++)
#define fd(i, x, y) for(int i = x, B = y; i >= B; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
using namespace std;
int T;
int a, b, m; ll n;
ll gcd(ll x, ll y) { return !y ? x : gcd(y, x % y);}
int g(int m, int d, int l, int r) {
ll x = l / d;
if(l % d) x ++;
if(x * d <= r) return x;
if(d > m - d) return g(m, m - d, m - r, m - l);
int k = g(d, (d - m % d) % d, l % d, r % d);
x = (ll) k * m + l;
ll y = x / d; if(x % d) y ++;
return y;
}
int calc(int m, int d, int l, int r) {
int gd = gcd(m, d);
if(l == 0) return 0;
if((l - 1) / gd >= r / gd) return m + 1;
return g(m, d, l, r);
}
int calc2(int l, int r) {
if(l - a >= 0)
return calc(m, b, l - a, r - a);
if(r - a < 0)
return calc(m, b, l - a + m, r - a + m);
return min(calc(m, b, l - a + m, m - 1), calc(m, b, 0, r - a));
}
int calc3(int x, int l, int r, int z) {
int st = ((ll) calc2(l, r) * b + a) % m, ans = 0;
while(st != x) {
if(z) l = st + 1; else
r = st - 1;
if(l > r) return ans;
int t = calc2(l, r);
if(t > m) return ans;
int nt = ((ll) t * b + a) % m;
if(z) {
int y = (r - st) / (nt - st);
ans += y;
st += y * (nt - st);
} else {
int y = (st - l) / (st - nt);
ans += y;
st -= y * (st - nt);
}
}
return ans;
}
int gg(int x, int F) {
int ans = 0;
x = ((ll) b * x + a) % m;
int s = calc3(x, 0, x, 1);
if(F == 1) {
int d = gcd(m, b);
if(x + d < m) s += calc3(x + d, x + d, m - 1, 0) + 1;
} else {
s += calc3(x, x, m - 1, 0);
}
return s;
}
int main() {
freopen("fuwa.in", "r", stdin);
freopen("fuwa.out", "w", stdout);
scanf("%d", &T);
fo(ii, 1, T) {
scanf("%d %d %d %lld", &a, &b, &m, &n);
b %= m;
a = (a + b) % m; n --;
int d = gcd(b, m);
int m2 = m / d;
pp("%lld\n", gg(n % m2, n >= m2) + (n / m2));
}
}