HDU 6059 Kanade's trio 字典樹 統計 容斥

鏈接

http://acm.hdu.edu.cn/showproblem.php?pid=6059

題意

給出A[1..n](1<=n<=5105)(0<=A[i]<230) ,要求統計三元組(i,j,k) 的個數使其滿足i<j<k 並且(A[i]xorA[j])<(A[j]xorA[k])

思路

事先把所有數字插入字典樹中,用字典樹維護A[k] 的信息,接着對每一個A[i] ,枚舉其二進制最高位小於A[k] 的位數,考慮這樣一個情況:若當前枚舉到了A[i] 的二進制第5位比A[k] 小,那麼A[i]A[k] 的第30位到第6位都是相同的,此時就不用考慮A[j] 的第30位到第6位如何,只考慮第5位的情況就好。
考慮當前位置A[i] 的情況,若當前位置A[i] 爲0,那麼A[j] 的相同位置要爲0才能使兩者異或值爲0,此時A[k] 爲1,這時滿足條件的A[j]A[k] 對數可以計入答案。當前位置A[i] 爲1的情況同理(A[j] 爲1,A[k] 爲0)。

對於A[j]A[k] 對數的統計,在插入A[k] 時,之前插入的數就都成了A[j] 。因此用一個cnt[i][j]數組記錄下第i位爲j的數之前出現了幾次,那麼在插入時,對於這一位置A[k] 爲0的情況,之前有多少的A[j] 在這一位爲1,就是此時滿足條件的A[j] 的個數。代碼裏的cnt[i][nxt ^ 1]就是此時符合條件的A[j] 個數。

當我們把一個數從字典樹中去掉時,也要考慮去掉這個數留下來的統計值。
這題特殊的地方在於,插入是連續的,之後是連續的刪除,所以在插入完成後可以把cnt[i][j]數組清空一遍,用來記錄第i位爲j的數被刪除了幾次。
考慮兩個方面:
- 一個是這個數作爲A[k] 直接被去掉帶來的影響,像之前一樣減去其前面已經被刪去的A[j] 的個數(依然是cnt[now][nxt ^ 1])就好。(這一步操作在Trie::Insert()裏面,與插入時的操作類似)
- 還有一個是這個數作爲A[j] 帶來的影響,因爲這個數已經不能和後面的A[k] 組合產生貢獻了,考慮到在統計時,當前位的A[k] 已經把可以與其組合的A[j] 個數統計在了sum[tmp]中,這裏面還需去掉被刪去的A[j] ,被刪去的A[j] 已經被統計在了cnt[i][nxt]中,現有的A[k] 被存在了val[tmp]中,這一部分不能被計入答案,相乘,減去。(sum[tmp] - val[tmp] * cnt[i][nxt]這一步在函數solve()裏)

希望思路說清楚了,詳見代碼

代碼

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

#define MS(x, y) memset(x, y, sizeof(x))

typedef long long LL;
const int MAXN = 5e5 + 5;
int bits[32];

struct Trie {
  int tot, root;
  int val[MAXN * 30], ch[MAXN * 30][2];
  LL sum[MAXN * 30], cnt[MAXN][2];

  int newnode() {
    val[tot] = sum[tot] = 0;
    ch[tot][0] = ch[tot][1] = -1;
    return tot++;
  }

  void init() {
    tot = 0;
    root = newnode();
    MS(cnt, 0);
  }

  void Insert(int x, int v) {
    int now = root, nxt, tmp;
    for (int i = 30; i >= 0; --i) {
      nxt = !!(x & bits[i]);
      if (ch[now][nxt] == -1) ch[now][nxt] = newnode();
      now = ch[now][nxt];
      ++cnt[i][nxt];
      sum[now] += v * cnt[i][nxt ^ 1];
      val[now] += v;
    }
  }

  LL solve(int x) {
    LL ret = 0;
    int now = root, tmp, nxt;
    for (int i = 30; i >= 0; --i) {
      nxt = !!(x & bits[i]);
      tmp = ch[now][nxt ^ 1];
      now = ch[now][nxt];
      if (tmp != -1) {
        ret += sum[tmp] - val[tmp] * cnt[i][nxt];
      }
      if (now == -1) break;
    }
    return ret;
  }
};

int n;
int a[MAXN];
LL ans;
Trie trie;

int main() {
  bits[0] = 1;
  for (int i = 1; i < 32; ++i) bits[i] = bits[i - 1] << 1;
  int T;
  scanf("%d", &T);
  while (T--) {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    ans = 0;
    trie.init();
    for (int i = 1; i <= n; ++i) trie.Insert(a[i], 1);
    MS(trie.cnt, 0);
    for (int i = 1; i < n; ++i) {
      trie.Insert(a[i], -1);
      ans += trie.solve(a[i]);
    }
    printf("%I64d\n", ans);
  }
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章