问题描述:
Bob has a favorite number k and ai of length n. Now he asks you to answer m queries. Each query is given by a pair li and ri and asks you to count the number of pairs of integers i and j, such that l ≤ i ≤ j ≤ r and the xor of the numbers ai, ai + 1, ..., aj is equal to k.
Input
The first line of the input contains integers n, m and k (1 ≤ n, m ≤ 100 000, 0 ≤ k ≤ 1 000 000) — the length of the array, the number of queries and Bob's favorite number respectively.
The second line contains n integers ai (0 ≤ ai ≤ 1 000 000) — Bob's array.
Then m lines follow. The i-th line contains integers li and ri (1 ≤ li ≤ ri ≤ n) — the parameters of the i-th query.
Output
Print m lines, answer the queries in the order they appear in the input.
解题思路:
前缀异或和:[L,R] = [0, L-1] ^ [0, R]。
在这个前提下,如果我们已知[L,R]内异或和为k的区间数量,那么我们可以在O(1)时间内推出L或R左移右移一格后,异或和为k的区间的数量。
以R右移一格变成R+1为例,此时增加的区间的右边界一定是R+1;这时有[l, R+1] = [0,l-1]^[0,R+1] = k (l [L, R]), 即[0,l-1] = k^[0,R+1], 所以增加的数量是L-1~ R中前缀异或和为k的前缀异或区间的数量。因此,只要我们保存了L-1到R的前缀异或和的数值和数量,就可以在O(1)时间内推出右边界右移一格时异或和为k的区间数量变化。
这时候我们就可以使用莫队算法了。
莫队算法是对区间进行分块,然后以块号顺序对区间进行查询。通常是以L/sqrt(n)作为块号,同一个块内按R排序。我们可以计算一下一个块内的时间复杂度。因为同块内按R排序,R只会往右移,时间复杂度为n。而同一个块内L的差值在sqrt(n)内,所以L会在sqrt(n)的范围内震荡。假设块内区间数为m1,则最坏时间复杂度为O(m1*sqrt(n)+n)。考虑所有区间,总的时间复杂度为O(m*sqrt(n) + n*sqrt(n)), 即在这里是n^(1.5)。
题解:
// 莫队算法
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;
const int MAX_N = 100010;
const int MAX_NUM = 1000010;
int xor_sum[MAX_N]; // 前缀异或和
int counts[2*MAX_NUM]; // 前缀异或和的值对应的数量
long long results[MAX_N]; // 查询的解
struct Query
{
int L, R, id, block;
}query[MAX_N];
// 先按块排序,再按右边界排序
bool operator<(Query& a, Query& b)
{
if (a.block == b.block) return a.R < b.R;
return a.L < b.L;
}
int main()
{
int n, m, k;
cin >> n >> m >> k;
memset(xor_sum, 0, sizeof(xor_sum));
memset(counts, 0, sizeof(counts));
int num;
for (int i = 1; i <= n; i++)
{
scanf("%d", &num);
xor_sum[i] = xor_sum[i-1] ^ num;
}
int s = sqrt(n);
for (int i = 0; i < m; i++)
{
scanf("%d%d", &query[i].L, &query[i].R);
query[i].id = i;
query[i].block = query[i].L / s;
}
sort(query, query+m);
int l = 1, r = 0;
counts[xor_sum[l-1]]++;
long long result = 0;
// [0,L-1] ^ [0,R] = [L,R] 所以要记录count([0,L-1] ~ [0,R])
for (int i = 0; i < m; i++)
{
// 右边界右移
while(r < query[i].R)
{
r++;
result += counts[k^xor_sum[r]];
counts[xor_sum[r]]++;
}
// 右边界左移
while (r > query[i].R)
{
counts[xor_sum[r]]--;
result -= counts[k^xor_sum[r]];
r--;
}
// 左边界右移
while (l < query[i].L)
{
counts[xor_sum[l-1]]--;
result -= counts[k^xor_sum[l-1]];
l++;
}
// 左边界左移
while (l > query[i].L)
{
l--;
result += counts[k^xor_sum[l-1]];
counts[xor_sum[l-1]]++;
}
results[query[i].id] = result;
}
for (int i = 0; i < m; i++)
printf("%I64d\n", results[i]);
}