整体二分算法完整总结

整体二分概述

一、适用问题

整体二分,即对所有的查询进行一个整体的二分答案,需要数据结构题满足以下性质。

  1. 询问的答案具有可二分性
  2. 修改对判定答案的贡献相对独立,修改之间互不影响效果
  3. 修改如果对判定答案有贡献,则贡献为一确定的与判定标准无关的值
  4. 贡献满足交换律、结合律,具有可加性
  5. 题目允许离线操作

(来自《浅谈数据结构题的几个非经典解法》)

上面的性质看上去复杂,其实只要满足询问答案具有可二分性,且题目允许离线操作,就可以考虑一下是否可以利用整体二分算法进行求解。

二、算法介绍

上面的文字可能有些过于理论,我们现在用浅显一点的方式来理解这个算法。

假设你现在有 qq 次查询,查询区间第 kk 大的值。首先考虑如果只有 11 个查询,是否可以直接二分解决。

显然是可以的,我们只需要定位到具体区间,数一下小于等于当前二分值的数个数是否大于等于 kk 即可。于是问题就变成了如何从单次二分演变到整体二分。

我们首先维护一个操作序列,即每个点的赋值和查询,共 n+qn+q 个操作。然后实现一个 solve(l,r,L,R)solve(l,r,L,R) 函数,表示当前的操作序列在 [L,R][L,R] 范围内,而该操作序列中所有的查询操作的答案都在 [l,r][l,r] 中。

于是我们二分一个值 mid=(l+r)/2mid=(l+r)/2,然后将 [L,R][L,R] 中所有的赋值操作中数值小于等于 midmid 的数加入到对应位置,比如 a[x]mida[x]\leq mid,则 sum[x]=sum[x]+1sum[x]=sum[x]+1,对于所有的查询操作,判断其查询区间 [x,y][x,y] 的值是否大于等于 kk,如果是则将其递归到 solve(l,mid)solve(l,mid) 中,否则递归到 solve(mid+1,r)solve(mid+1,r) 中,具体内容看一下下面的例题就可以理解。

最后分析一下时间复杂度,最多分了 lognlogn 层,每一层的时间复杂度为 O(nlogn)O(nlogn),因此总时间复杂度为 O(nlog2n)O(nlog^2n)

最后附上《浅谈数据结构题的几个非经典解法》中对该算法的理论概述。

询问的答案可二分且修改对判定标准的贡献相对独立,且贡献的值与判定标准无关。因此如果我们已经计算过某一些修改对询问的贡献,那么这个贡献永远不会改变,我们没有必要当判定标准改变时再次计算这部分修改的贡献,只要记录下当前的总贡献,再进一步二分时,直接加上新的贡献即可。


整体二分系列习题

1. K-th Number

题意: 无修改的区间第 kk 大数问题。(1n105,1m5000)(1\leq n\leq 10^5,1\leq m\leq 5000)

思路: 主席树模板题,但此处我们要用整体二分的方法来解决此题。

首先我们将所有赋值操作和查询操作都放到一个数组中,形成了此题的操作序列。然后就是代码中的核心关键点 solve(l,r,L,R)solve(l,r,L,R) 函数,该函数表示区间 [L,R][L,R] 中的操作序列中的查询操作的答案一定在 [l,r][l,r] 范围内。

因此问题就变成了如何将 [L,R][L,R] 中的序列进行分组,再递归到 solve(l,mid)solve(l,mid)solve(mid+1,r)solve(mid+1,r) 中。我们只需遍历 [L,R][L,R] 中的所有操作,如果是赋值操作,则判断数值 xx 是否大于 midmid,如果小于等于 midmid,则将该操作丢到 q1q_1 数组中,并在树状数组的 xx 位置加 11;否则将操作丢到 q2q_2 数组中。

如果是查询操作,则在树状数组中查询区间 [q[i].x,q[i].y][q[i].x,q[i].y] 的值 tmptmp,如果 ktmpk\leq tmp,则将该操作丢到 q1q_1 中;否则将 k=ktmpk=k-tmp,然后丢到 q2q_2 中。

可能说起来比较复杂,但是代码比较清晰,推荐直接对代码进行理解。

代码:

#include <iostream>
#include <algorithm>
#include <cstdio>
#define mem(a,b) memset(a,b,sizeof a);
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define per(i,a,b) for(int i = a; i >= b; i--)
#define __ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
typedef long long ll;
typedef double db;
const int N = 1e5+100;
const int inf = 1e9+10;
const db EPS = 1e-9;
using namespace std;

void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}

int n,m,ans[N],c[N];
struct Node{int x,y,k,id;}q[2*N],q1[2*N],q2[2*N];
inline int lowbit(int x) {return x&(~x+1);}
inline void update(int x,int v) {for(; x<=n; x+=lowbit(x)) c[x] += v;}
inline int ask(int x){
	int res = 0;
	while(x) res += c[x], x -= lowbit(x);
	return res;
}

void solve(int l,int r,int L,int R){
	if(l > r || L > R) return;
	if(l == r){
		rep(i,L,R) if(q[i].id) ans[q[i].id] = l;
		return;
	}
	int cnt1 = 0, cnt2 = 0, mid = (l+r)>>1;
	rep(i,L,R){
		if(q[i].id){ //查询
			int tmp = ask(q[i].y)-ask(q[i].x-1);
			if(q[i].k <= tmp) q1[++cnt1] = q[i];
			else q[i].k -= tmp, q2[++cnt2] = q[i];
		}
		else{ //赋值
			if(q[i].x <= mid) update(q[i].y,1), q1[++cnt1] = q[i];
			else q2[++cnt2] = q[i];
		}
	}
	rep(i,1,cnt1) if(!q1[i].id) update(q1[i].y,-1);
	rep(i,1,cnt1) q[L+i-1] = q1[i];
	rep(i,1,cnt2) q[L+cnt1+i-1] = q2[i];
	solve(l,mid,L,L+cnt1-1); solve(mid+1,r,L+cnt1,R);
}

int main()
{
	scanf("%d%d",&n,&m);
	rep(i,1,n) {scanf("%d",&q[i].x); q[i].id = 0; q[i].y = i;}
	rep(i,1,m) {scanf("%d%d%d",&q[i+n].x,&q[i+n].y,&q[i+n].k); q[i+n].id = i;}
	solve(-inf,inf,1,n+m);
	rep(i,1,m) printf("%d\n",ans[i]);
	return 0;
}
2. Dynamic Rankings

题意: 带修改的区间第 kk 大数问题。(1n5104,1m104)(1\leq n\leq 5*10^4,1\leq m\leq 10^4)

思路: 带修改第 kk 大问题,如果要用主席树来解决的话,则需要再加上一层树状数组来维护修改信息,即用树套树解决该问题。

但是如果用整体二分来处理这个问题的话,难度则会瞬间骤降。其实此题与上题唯一的区别就是这题多了一个修改操作,而修改操作无非就是删除原来的数,加上新的数。

因此对于 a[x]=ya[x]=y 的修改操作,我们将其拆成两部分,第一部分为删除 a[x]a[x],然后令 a[x]=ya[x]=y,第二部分是加上 a[x]a[x],具体细节可以参考下面的代码实现。

代码:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#define mem(a,b) memset(a,b,sizeof a);
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define per(i,a,b) for(int i = a; i >= b; i--)
#define __ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
typedef long long ll;
typedef double db;
const int N = 1e5+100;
const int inf = 1e9+10;
const db EPS = 1e-9;
using namespace std;

void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}

int n,m,ans[N],c[N],a[N];
struct Node{int x,y,k,id;}q[2*N],q1[2*N],q2[2*N];
inline int lowbit(int x) {return x&(~x+1);}
inline void update(int x,int v) {for(; x<=n; x+=lowbit(x)) c[x] += v;}
inline int ask(int x){
	int res = 0;
	while(x) res += c[x], x -= lowbit(x);
	return res;
}

void solve(int l,int r,int L,int R){
	if(l > r || L > R) return;
	if(l == r){
		rep(i,L,R) if(q[i].k) ans[q[i].id] = l;
		return;
	}
	int cnt1 = 0, cnt2 = 0, mid = (l+r)>>1;
	rep(i,L,R){
		if(q[i].k){ //查询
			int tmp = ask(q[i].y)-ask(q[i].x-1);
			if(q[i].k <= tmp) q1[++cnt1] = q[i];
			else q[i].k -= tmp, q2[++cnt2] = q[i];
		}
		else{ //赋值
			if(q[i].x <= mid) update(q[i].id,q[i].y), q1[++cnt1] = q[i];
			else q2[++cnt2] = q[i];
		}
	}
	rep(i,1,cnt1) if(!q1[i].k) update(q1[i].id,-q1[i].y);
	rep(i,1,cnt1) q[L+i-1] = q1[i];
	rep(i,1,cnt2) q[L+cnt1+i-1] = q2[i];
	solve(l,mid,L,L+cnt1-1); solve(mid+1,r,L+cnt1,R);
}

int main()
{
	int _; scanf("%d",&_);
	while(_--){
		scanf("%d%d",&n,&m);
		int cnt = 0, tot = 0;
		memset(c,0,sizeof c);
		rep(i,1,n) {
			scanf("%d",&a[i]);
			q[++cnt] = {a[i],1,0,i};
		}
		rep(i,1,m){
			char op[5]; scanf("%s",op);
			int x,y,k;
			if(op[0] == 'Q'){
				scanf("%d%d%d",&x,&y,&k);
				q[++cnt] = {x,y,k,++tot};
			}
			else{
				scanf("%d%d",&x,&y);
				q[++cnt] = {a[x],-1,0,x};
				a[x] = y;
				q[++cnt] = {a[x],1,0,x};
			}
		}
		solve(-inf,inf,1,cnt);
		rep(i,1,tot) printf("%d\n",ans[i]);
	}
	return 0;
}
3. K大数查询

题意: nn 个位置,mm 个操作。操作有两种,1 a b c1\ a\ b\ c 表示在第 aa 个位置到第 bb 个位置,每个位置加入一个数 cc2 a b c2\ a\ b\ c 表示询问从第 aa 个位置到第 bb 个位置,第 cc 大的数是多少。(1n,m5104)(1\leq n,m\leq 5*10^4)

思路: 其实和上面第二个问题没有太大的差别,只不过上一个问题是单点修改,而这题变成了区间修改。因此我们用线段树维护一下整体二分即可解决。

代码:

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#define mem(a,b) memset(a,b,sizeof a);
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define per(i,a,b) for(int i = a; i >= b; i--)
#define __ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
typedef long long ll;
typedef double db;
const int N = 1e5+100;
const int inf = 1e9+10;
const db EPS = 1e-9;
using namespace std;

void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}

int n,m;
ll ans[N],sum[2*N],lazy[2*N];
struct Node{ll x,y,k,id;}q[2*N],q1[2*N],q2[2*N];

inline int get_id(int l,int r) {return (l+r)|(l!=r);}
inline void pushDown(int l,int r){
	int mid = (l+r)>>1, now = get_id(l,r), ls = get_id(l,mid), rs = get_id(mid+1,r);
	sum[ls] += lazy[now]*(ll)(mid-l+1); sum[rs] += lazy[now]*(ll)(r-mid);
	lazy[ls] += lazy[now]; lazy[rs] += lazy[now];
	lazy[now] = 0;
}
inline void update(int l,int r,int L,int R,int v){
	int now = get_id(l,r);
	if(L <= l && r <= R){
		sum[now] += (ll)v*(ll)(r-l+1);
		lazy[now] += v;
		return;
	}
	if(lazy[now]) pushDown(l,r);
	int mid = (l+r)>>1;
	if(L <= mid) update(l,mid,L,R,v);
	if(R > mid) update(mid+1,r,L,R,v);
	sum[now] = sum[get_id(l,mid)]+sum[get_id(mid+1,r)];
}
inline ll query(int l,int r,int L,int R){
	int now = get_id(l,r);
	if(L <= l && r <= R) return sum[now];
	if(lazy[now]) pushDown(l,r);
	int mid = (l+r)>>1;
	ll thp = 0;
	if(L <= mid) thp += query(l,mid,L,R);
	if(R > mid) thp += query(mid+1,r,L,R);
	return thp; 
}

void solve(int l,int r,int L,int R){
	if(l > r || L > R) return;
	if(l == r){
		rep(i,L,R) if(q[i].id) ans[q[i].id] = l;
		return;
	}
	int cnt1 = 0, cnt2 = 0, mid = (l+r)>>1;
	rep(i,L,R){
		if(q[i].id){ //查询
			ll tmp = query(1,n,q[i].x,q[i].y);
			if(q[i].k <= tmp) q2[++cnt2] = q[i];
			else q[i].k -= tmp, q1[++cnt1] = q[i];
		}
		else{ //赋值
			//由于右区间的起始点为mid+1, 因此此处为 >= mid+1
			if(q[i].k >= mid+1) update(1,n,q[i].x,q[i].y,1), q2[++cnt2] = q[i];
			else q1[++cnt1] = q[i];
		}
	}
	rep(i,1,cnt2) if(!q2[i].id) update(1,n,q2[i].x,q2[i].y,-1);
	rep(i,1,cnt1) q[L+i-1] = q1[i];
	rep(i,1,cnt2) q[L+cnt1+i-1] = q2[i];
	solve(l,mid,L,L+cnt1-1); solve(mid+1,r,L+cnt1,R);
}

int main()
{
	scanf("%d%d",&n,&m);
	int tot = 0;
	rep(i,1,m){
		ll op,x,y,k;
		scanf("%lld%lld%lld%lld",&op,&x,&y,&k);
		if(op == 1) q[i] = {x,y,k,0};
		else q[i] = {x,y,k,++tot};
	}
	solve(-inf,inf,1,m);
	rep(i,1,tot) printf("%lld\n",ans[i]);
	return 0;
}
4. Stamp Rally

题意: 一个 nn 个点,mm 条边的图,第 ii 条边连接 aia_ibib_i,保证图是连通的。

现在有 qq 次询问,每次询问给出一个三元组 x y zx\ y\ z,表示询问从 x yx\ y 两个点出发,一共扩展 zz 个不同的点(包括起始点),求所经过的边中最大编号的最小值。(3n105,1q105)(3\leq n\leq 10^5,1\leq q\leq 10^5)

思路: 此题较之上面三题,没有那么套路,因此我们先从只有一个询问开始找思路。

首先考虑能不能把图变成树,因为图上问题往往都很复杂,而变成树上问题后我们的可操作空间会大很多。继续思考不难发现,如果我们按边的编号为权值构建一棵最小生成树,每次询问的答案也一定会落在最小生成树上的边上。

转到树上问题之后,我们考虑能不能二分答案然后 checkcheck,如果只有一个询问的话,显然是可以的。只需要维护一个可加边可删边的按秩合并的并查集即可。

既然单个查询可以二分,那一定可以用整体二分的方法对所有查询进行二分。我们在二分值为 midmid 时,将所有编号小于等于 midmid 的边连接起来,然后再递归到 [mid+1,r][mid+1,r] 区间。右区间递归结束后,再撤销二分值为 midmid 时连接的边,然后递归 [l,mid][l,mid]

如果连接操作是 fa[x]=y,sz[y]=sz[y]+sz[x]fa[x]=y,sz[y]=sz[y]+sz[x],那么撤销操作的时候,如果只有 fa[x]=x,sz[y]=sz[y]sz[x]fa[x]=x,sz[y]=sz[y]-sz[x] 是不够的,需要从 yy 开始不断向上访问,对于访问到的每一个节点都减去 sz[x]sz[x],如此才能保证撤销操作的正确性。

代码:

#include <bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a);
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define per(i,a,b) for(int i = a; i >= b; i--)
#define __ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
typedef long long ll;
typedef double db;
const db EPS = 1e-9;
const int N = 1e6+100;
using namespace std;

void dbg() {cout << "\n";}
template<typename T, typename... A> void dbg(T a, A... x) {cout << a << ' '; dbg(x...);}
#define logs(x...) {cout << #x << " -> "; dbg(x);}

int n,m,Q,fa[N],sz[N],ans[N];
struct Node {int x,y,k,id,h1,h2;} q[N],q1[N],q2[N];

int find(int x) {return x == fa[x] ? x : find(fa[x]);}
int calc(int x,int y){
	int fx = find(x), fy = find(y);
	return fx == fy ? sz[fx] : (sz[fx] + sz[fy]);
}
pair<int,int> merge(int x,int y){
	int fx = find(x), fy = find(y);
	if(fx == fy) return make_pair(-1,-1);
	if(sz[fx] < sz[fy]){
		fa[fx] = fy, sz[fy] += sz[fx];
		return make_pair(fx,fy);
	}
	else{
		fa[fy] = fx, sz[fx] += sz[fy];
		return make_pair(fy,fx);
	}
}
void Delete(int x,int y) {
	fa[x] = x;
	while(y){
		sz[y] -= sz[x];
		if(y == fa[y]) break;
		y = fa[y];
	}
}

void solve(int l,int r,int L,int R){
	if(l > r || L > R) return;
	if(l == r){
		rep(i,L,R) if(q[i].k) ans[q[i].id] = l;
		return;
	}
	int mid = (l+r)>>1, cnt1 = 0, cnt2 = 0;
	// logs(mid,L,R);
	rep(i,L,R){
		if(q[i].k){ //查询
			int tmp = calc(q[i].x,q[i].y);
			if(q[i].k <= tmp) q1[++cnt1] = q[i];
			else q2[++cnt2] = q[i];
		}
		else{
			if(q[i].id <= mid){
				pair<int,int> tmp = merge(q[i].x,q[i].y);
				q1[++cnt1] = q[i];
				q1[cnt1].h1 = tmp.first; q1[cnt1].h2 = tmp.second;
			}
			else q2[++cnt2] = q[i];
		}
	}
	rep(i,1,cnt1) q[L+i-1] = q1[i];
	rep(i,1,cnt2) q[L+cnt1+i-1] = q2[i];
	solve(mid+1,r,L+cnt1,R);
	rep(i,1,cnt1) if(!q[L+i-1].k) Delete(q[L+i-1].h1,q[L+i-1].h2);
	solve(l,mid,L,L+cnt1-1);
}

int main()
{
	int cnt = 0;
	scanf("%d%d",&n,&m);
	rep(i,1,n) fa[i] = i, sz[i] = 1;
	rep(i,1,m){
		int x,y; scanf("%d%d",&x,&y);
		if(merge(x,y).first == -1) continue;
		q[++cnt] = {x,y,0,i,0,0};
		// logs(i);
	}
	scanf("%d",&Q);
	rep(i,1,Q){
		int x,y,z; scanf("%d%d%d",&x,&y,&z);
		q[++cnt] = {x,y,z,i,0,0};
	}
	rep(i,1,n) fa[i] = i, sz[i] = 1;
	solve(0,m+1,1,cnt);
	rep(i,1,Q) printf("%d\n",ans[i]);
	return 0;
}

后记

整体二分的内容到这里就结束了,总的来说,该算法应该属于一种解题套路,所需要的学习时间不长,可以当作解决数据结构问题的一种套路。最后祝大家 ACAC 愉快,一起爱上二分把!(๑•̀ㅂ•́)و✧

ACM 的旅行虽然充满荆棘但一擡头便能看见无数束光,请务必坚持下去,负重前行终有云开雾散之日!💪💪💪

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章