[省選聯考 2024] 魔法手杖

退役三年選手回來做了下~

這題直觀感覺很嚇人,其實看到異或就可以往 Trie 樹上思考了。

這題有兩個未知量 \(S\)\(x\),其中 \(S\subseteq [n]\)\(x\in[0,2^k)\cap\Z\),狀態過於複雜,肯定不能枚舉,從答案的角度考慮。首先直觀感受是有點像二分,其實我們可以從高位往低位確定答案 \(ans\)

分析一下得出結論:\(ans\) 的二進制長度最長爲 \(k+1\) 位(注意這裏是個特殊情況),當且僅當 \(\sum b_i\le m\),這個時候根據貪心,取 \(x=2^k-1\) 即可,此時 \(ans=\min\{b_i\}+x\)

否則,\(ans\) 最長爲 \(k\) 位,我們可以從高往低確定 \(ans\)。假設當前正在討論第 \(d\) 位,根據貪心,我們需要判斷第 \(d\) 位是否可以爲 1。假如我們把所有的 \(a_i\) 插入一棵 Trie 樹,對於 \(val(S,x)=\min(\min \limits_{i \in S}(a_i+x),\min \limits_{i \in U \backslash S}(a_i \oplus x))\),我們稱由 \(+\) 部分和 \(\oplus\) 部分組成。如果希望結果越大,也就是兩個部分必須超過 \(ans\) 纔可以。

這個時候,我們需要先確定 \(x\) 的第 \(d\) 位的取值,我們分 \(x=0\)\(x=1\) 兩種情況。在此基礎上,\(S\) 的決策會影響 \(ans\) 的第 \(d\) 位。我們需要仔細分析一下,看一看能不能變成一個判定性問題。

假設此時在 Trie 樹上的節點 \(u\),如果左右孩子都存在,當 \(x\) 取 0 時,左子樹的所有值在 \(\oplus\) 部分會貢獻出 \(ans=0\),右子樹的所有值在 \(\oplus\) 部分會貢獻出 \(ans=1\),此時我們可以考慮將左子樹移到 \(+\) 部分,嘗試一下 \(ans\) 變成 1。考慮下影響:\(S\) 中的最小值更新了,這裏用 \(y\) 表示,那麼需要滿足兩個條件:1、左子樹的 \(b\) 的和不超過剩餘值(初始值位 \(m\));2、\(+\) 部分必須超過當前的 \(ans\)\(ans\) 未確定部分全部置成 0,\(x\) 未確定部分全部置成 1,極端情況的考慮)。只有這兩個條件滿足了,才能確定 \(ans\) 的第 \(d\) 位可以爲 1,那麼遞歸右子樹,繼續確定 \(d-1\) 位。\(x\) 取 1 時剛好對稱。

所以得到了一個 貪心+dfs 的思路。先判斷 \(ans\) 是否可以爲 1,如果可以需要遞歸兩種情況(合法時才遞歸),否則 \(ans\) 爲 0,此時 \(x\) 取 0 時右子樹可以忽略,遞歸左子樹,反之亦然。也是兩種情況。注意一下邊界(\(d=-1\) 或者 \(u\) 是空節點)。

時間複雜度 \({\rm O}(\sum nk)\),空間複雜度 \({\rm O}(nk)\)。代碼想清楚的話很好寫。

由於官方數據還沒出來,大樣例和民間數據都過了,先佔個位置。代碼有問題回頭再修改。

#include<bits/stdc++.h>
using namespace std;

void read(__int128 &x){
    // read a __int128 variable
    char c; bool f = 0;
    while (((c = getchar()) < '0' || c > '9') && c != '-');
    if (c == '-') { f = 1; c = getchar(); }
    x = c - '0';
    while ((c = getchar()) >= '0' && c <= '9') x = x * 10 + c - '0';
    if (f) x = -x;
}
void write(__int128 x){
    // print a __int128 variable
    if (x < 0) { putchar('-'); x = -x; }
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

typedef long long ll;
const int N = 1e5 + 5, K = 120 + 5;
int T, c, n, m, k, ch[N * K][2], cnt;
int b[N]; ll sumb[N * K];
__int128 a[N], mina[N * K], ans;
int nnode() {
    cnt++;
    ch[cnt][0] = ch[cnt][1] = 0;
    mina[cnt] = (__int128)1 << k;
    sumb[cnt] = 0;
    return cnt;
}
void solve(int o, int d, int t, __int128 x, __int128 y, __int128 z) {
    if (d == -1) {
        ans = max(ans, z);
        return;
    }
    __int128 bit = (__int128)1 << d, mask = bit - 1;
    if (!o) {
        ans = max(ans, y + (x | mask | bit));
        return;
    }
    int lc = ch[o][0], rc = ch[o][1];
    bool flag = false;
    if (sumb[lc] <= t && (x | mask) + min(y, mina[lc]) >= (z | bit))
        solve(rc, d - 1, t - sumb[lc], x, min(y, mina[lc]), z | bit), flag = true;
    if (sumb[rc] <= t && (x | mask | bit) + min(y, mina[rc]) >= (z | bit))
        solve(lc, d - 1, t - sumb[rc], x | bit, min(y, mina[rc]), z | bit), flag = true;
    if (flag) return;
    solve(lc, d - 1, t, x, y, z);
    solve(rc, d - 1, t, x | bit, y, z);
}
int main() {
    freopen("xor.in", "r", stdin);
    freopen("xor.out", "w", stdout);
    scanf("%d%d", &c, &T);
    while (T--) {
        scanf("%d%d%d", &n, &m, &k);
        for (int i = 1; i <= n; i++) read(a[i]);
        for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
        cnt = 0; nnode();
        __int128 MAX = (__int128)1 << k;
        mina[0] = MAX;
        for (int i = 1; i <= n; i++) {
            int u = 1;
            sumb[u] += b[i];
            mina[u] = min(mina[u], a[i]);
            for (int j = k - 1; ~j; j--) {
                int x = a[i] >> j & 1;
                if (!ch[u][x]) ch[u][x] = nnode();
                u = ch[u][x];
                mina[u] = min(mina[u], a[i]);
                sumb[u] += b[i];
            }
        }
        ans = 0;
        if (sumb[1] <= m)
            ans = mina[1] + MAX - 1;
        else
            solve(1, k - 1, m, 0, MAX, 0);
        write(ans); putchar('\n');
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章