【2019 BAPC - D】Deck Randomisation【中國剩餘定理 + 循環節】

題意

洗牌遊戲。初始順序爲 1n1~nAliceAliceBobBob 輪流操作此牌堆。AliceAlice 操作時會將位置 ii 上的牌移到 a[i]a[i] 上;BobBob 操作時會將位置 ii 上的牌移到 b[i]b[i] 上,問最少移動多少次桌上的牌會恢復原狀,若 101210^{12} 次內無法還原,則輸出 hugehuge(1n105)(1\leq n\leq 10^5)


思路

首先一個很明顯的思路,對結束點的分類枚舉,即該遊戲可能結束於 AliceAlice,也可能結束於 BobBob

若結束於 BobBob,則可以將 AliceAliceBobBob 的兩個操作合成一個操作,然後問最少多少步可以還原。此時答案比較好求,即先求出遊戲中的多個循環節,各個循環節長度的最小公倍數即爲答案。

若結束於 AliceAlice,則修改每一個點的結束位置,因此對於每個循環節來說,最少操作次數應爲 kT[i]+mod[i]k*T[i]+mod[i],其中 kk 爲任意非負整數,T[i]T[i] 爲該循環節的長度,mod[i]mod[i] 表示該循環節中的節點距離其目標節點的最短距離。

因此我們可以發現最終的答案需要求解一個同餘方程組,因此使用中國剩餘定理進行求解。

比賽時思路就已經到達了此步,但仍然無論如何都無法 ACAC,主要有兩個原因:

  1. 中國剩餘定理板子不夠優秀(大數據無法處理)
  2. 計算過程爆 longlong longlong 了,用 __int128\_\_int128 也不行,最後改寫成 pythonpython 後通過

反思

一道題懟的太久,導致錯失了 ACAC 其他題的可能性。除此之外,數論方面的板子過於陳舊,今後遇到數論題需要主動承擔不能逃避。最後,中國剩餘定理存在爆 longlong longlong 的可能,需要及時反應過來並將其換成 pythonpythonjavajava


代碼

cppcpp 代碼,最後兩個測試點無法通過。

#include <bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a);
#define rep(i,a,b) for(ll i = a; i <= b; i++)
#define per(i,a,b) for(ll i = a; i >= b; i--)
#define __ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
#define ll __int128
typedef double db;
const db EPS = 1e-9;
const int N = 1e5+100;
const ll inf = 1e17;
using namespace std;

void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}

// tot 個循環結,T[i] 每個循環節長度
// vis[i] 第 i 個點所在循環節位置
ll n,a[N],b[N],c[N],tot,vis[N];
ll T[N], mod[N], m[N];

void init(){
	rep(i,1,n) c[i] = b[a[i]];
	rep(i,1,n){
		if(vis[i]) continue;
		ll tmp = c[i], len = 1;
		while(tmp != i){
			tmp = c[tmp];
			len++;
		}
		T[++tot] = len;
		vis[i] = tot; tmp = c[i];
		while(tmp != i){
			vis[tmp] = tot;
			tmp = c[tmp];
		}
	}
}

ll gcd(ll a, ll b){
	return b == 0 ? a : gcd(b, a%b);
}

ll solve1(){
	ll ans = T[1];
	rep(i,2,tot){
		ans = ans * T[i] / gcd(ans, T[i]);
		if(ans > 1e12) return 1e13;
	}
	return ans * 2ll;
}

ll exgcd(ll a, ll b, ll &x, ll &y) {
    if (b == 0) {x = 1, y = 0; return a;}
    ll r = exgcd(b, a % b, x, y), tmp;
    tmp = x; x = y; y = tmp - (a / b) * y;
    return r;
}
ll inv(ll a, ll b) {
    ll x = 0, y = 0;
    ll r = exgcd(a, b, x, y);
    while (x < 0) x += b;
    return x;
}
ll excrt(ll n, ll M[], ll C[]){ // % m = c
    for (ll i = 2; i <= n; i++) {
        ll M1 = M[i - 1], M2 = M[i], C2 = C[i], C1 = C[i - 1], T = gcd(M1, M2);
        if ((C2 - C1) % T != 0) return 1e13;
        M[i] = (M1 * M2) / T;
        C[i] = ( inv( M1 / T , M2 / T ) * (C2 - C1) / T ) % (M2 / T) * M1 + C1;
        C[i] = (C[i] % M[i] + M[i]) % M[i];
        if(C[i] > 1e12) return 1e13;
    }
    return C[n];
}

ll tt[N];

ll solve2(){
	rep(i,1,n) tt[i] = a[i];
	rep(i,1,n) a[tt[i]] = i;
	rep(i,1,n) m[i] = inf;
	rep(i,1,n){
		if(m[i] != inf) continue;
		if(vis[a[i]] != vis[i]) return 1e13;
		else{
			ll len = 0, tmp = i;
			while(tmp != a[i]){
				tmp = c[tmp];
				len++;
			}
			m[i] = len;
			ll p1 = c[i], p2 = c[a[i]];
			while(p1 != i){
				if(a[p1] != p2) return 1e13;
				else m[p1] = len;
				p1 = c[p1];
				p2 = c[p2];
			}
		}
	}
	rep(i,1,n) mod[i] = inf;
	rep(i,1,n)
		if(m[i] == T[vis[i]]) m[i] = 0;
	rep(i,1,n){
		if(mod[vis[i]] == inf) mod[vis[i]] = m[i];
		else if(mod[vis[i]] != m[i]) return 1e13;
	}
	return (excrt(tot,T,mod) * 2ll + 1ll);
}

inline __int128 read() {
    __int128 x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9') {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9') {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*f;
}

inline void write(__int128 x) {
    if(x<0) { putchar('-'); x=-x; }
    if(x>9) write(x/10);
    putchar(x%10+'0');
}

int main()
{
	n = read();
	rep(i,1,n) a[i] = read();
	rep(i,1,n) b[i] = read();
	ll ans = 1e13;
	init();
	ll tp = solve1();
	if(tp > 0) ans = min(ans,tp);
	tp = solve2();
	if(tp > 0) ans = min(ans,tp);
	if(ans > 1e12) printf("huge");
	else write(ans);
	return 0;
}

改寫成 pythonpython 代碼後,可以通過全部測試點。

inf, N = int(1e17), int(100100)
# tot 個循環結,T[i] 每個循環節長度
# vis[i] 第 i 個點所在循環節位置
n, a, b, c, tot, vis = 0, [0], [0], [0]*N, 0, [0]*N
T, mod, m, tt, tmp, len = [0]*N, [0]*N, [0]*N, [0]*N, 0, 0


def init():
    global n, a, b, c, tot, vis, T, mod, m, tt, inf, N, tmp, len
    for i in range(1, n+1):
        c[i] = b[a[i]]
    for i in range(1, n+1):
        if vis[i] != 0:
            continue
        tmp, len = c[i], 1
        while tmp != i:
            tmp = c[tmp]
            len += 1
        tot += 1
        T[tot] = len
        vis[i] = tot
        tmp = c[i]
        while tmp != i:
            vis[tmp] = tot
            tmp = c[tmp]


def gcd(a, b):
    if b == 0:
        return a
    else:
        return gcd(b, a % b)


def solve1():
    global n, a, b, c, tot, vis, T, mod, m, tt, inf, N
    ans = 0
    ans = T[1]
    for i in range(2, tot+1):
        ans = ans * T[i] / gcd(ans, T[i])
        if ans > 1e12:
            return int(1e13)
    return ans * int(2)


def exgcd(a, b, x, y):
    re, tmp = 0, 0
    if b == 0:
        return a, 1, 0
    re, x, y = exgcd(b, a % b, x, y)
    tmp = x
    x = y
    y = tmp - (a//b)*y
    return re, x, y


def inv(a, b):
    x, y, r = 0, 0, 0
    r, x, y = exgcd(a, b, x, y)
    while x < 0:
        x += b
    return x


def excrt(n, M, C):
    M1, M2, C1, C2, T = 0, 0, 0, 0, 0
    for i in range(2, n+1):
        M1, M2, C1, C2 = M[i-1], M[i], C[i-1], C[i]
        T = gcd(M1, M2)
        if (C2-C1) % T != 0:
            return 1e13
        M[i] = (M1 * M2) // T
        C[i] = (inv(M1//T, M2//T) * (C2-C1) // T) % (M2 // T) * M1 + C1
        C[i] = (mod[i] % M[i] + M[i]) % M[i]
        if C[i] > 1e12:
            return 1e13
    return C[n]


def solve2():
    global n, a, b, c, tot, vis, T, mod, m, tt, inf, N, tmp, len
    for i in range(1, n+1):
        tt[i] = a[i]
    for i in range(1, n+1):
        a[tt[i]] = i
    for i in range(1, n+1):
        m[i] = inf
    for i in range(1, n+1):
        if m[i] != inf:
            continue
        if vis[a[i]] != vis[i]:
            return int(1e13)
        else:
            tmp = 0
            len, tmp, p1, p2 = 0, i, 0, 0
            while tmp != a[i]:
                tmp = c[tmp]
                len += 1
            m[i] = len
            p1, p2 = c[i], c[a[i]]
            while p1 != i:
                if a[p1] != p2:
                    return int(1e13)
                else:
                    m[p1] = len
                p1 = c[p1]
                p2 = c[p2]
    for i in range(1, n+1):
        mod[i] = inf
    for i in range(1, n+1):
        if m[i] == T[vis[i]]:
            m[i] = 0
    for i in range(1, n+1):
        if mod[vis[i]] == inf:
            mod[vis[i]] = m[i]
        elif mod[vis[i]] != m[i]:
            return int(1e13)
    return excrt(tot, T, mod) * 2 + 1


def main():
    global n, a, b, c, tot, vis, T, mod, m, tt, inf, N
    n = int(input())
    a.extend(list(map(int, input().split())))
    b.extend(list(map(int, input().split())))
    ans = int(1e13)
    init()
    ans = min(ans, solve1())
    ans = min(ans, solve2())
    if ans > 1e12:
        print("huge")
    else:
        print(int(ans))


if __name__ == "__main__":
    main()

後記

場上自閉 2h2h,場下自閉 5h5h,真實自閉題…

不過不得不說,好久沒這麼自閉調題了,ACAC 的那一刻的確是直接跳起來了,久違的感覺!

最後,請務必繼續奔跑,追逐屬於自己的那束光!💪💪💪

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章