權值線段樹 學習筆記

權值線段樹

權值線段樹是基於線段樹的一種數據結構;

線段樹維護的是區間信息,比如區間和,區間最大值等等;

而權值線段樹維護的是全局的值域信息,每個結點記錄是該結點所包含區間的值出現的次數;

權值線段樹支持:

  1. 查詢全局第k小值
  2. 查詢某個值全局的排名
  3. 查詢全局某個值的前驅
  4. 查詢全局某個值的後繼

這些操作本來是平衡樹的基本操作,但是權值線段樹也可以實現,並且代碼更加簡單;

但是一個很大的劣勢是當值域較大(10^9)時,需要離散化,也就是變成了離線算法;

這裏主要講權值線段樹的這幾個用法,具體實現過程不展開:

以下所有代碼都來做於這道題:P3369 【模板】普通平衡樹

1、建樹:

每道題的建樹都不太一樣,但是大同小異,記住這個跟普通線段樹的區別在於葉子結點記錄是這個結點所包含的區間的值出現次數;

void build(int l,int r,int k){
	tr[k].l=l,tr[k].r=r;
	if(l==r) return;
	int d=(l+r)>>1;
	build(l,d,ls);
	build(d+1,r,rs);
}

2、插入刪除

插入刪除和普通線段樹一樣,當加入某個數時,加 1 ,刪除某個數時,加 -1;

void update(int x,int w,int k){//在x點加1 
	if(tr[k].l==tr[k].r){
		tr[k].w+=w;
		return;
	}
	int d=(tr[k].l+tr[k].r)>>1;
	if(x<=d) update(x,w,ls);
	else update(x,w,rs);
	tr[k].w=tr[ls].w+tr[rs].w;
} 

3、查詢全局第K小值

這裏要注意的點就是,當左子樹的值大於K時,說明左子樹有解,否則遞歸右子樹,但是記得減去左子樹的值;

void Xth(int x,int k){//第x小值 
	if(tr[k].l==tr[k].r){
		ans=tr[k].l;
		return;
	}
	if(tr[ls].w>=x) Xth(x,ls);
	else Xth(x-tr[ls].w,rs); 
}

4、查詢某個值x的全局排名

其實就是查找1 – x-1 這個區間和,然後再加 1 ;

void Rank(int x,int k){//x的排名,就是求1-(x-1)的區間和 
	if(tr[k].l>=1&&tr[k].r<=x){
		ans+=tr[k].w;
		return;
	}
	int d=(tr[k].l+tr[k].r)>>1;
	if(1<=d) Rank(x,ls);
	if(x>d) Rank(x,rs);
}

5、查詢某個值x的前驅;

這個就比較複雜,其實主要就是二分思想;

這裏貼一下代碼,主要還是自己理解:

int Findp(int k){//找到這結點k區間的最右邊的數 
	if(tr[k].l==tr[k].r) return tr[k].l;
	if(tr[rs].w) return Findp(rs);
	return Findp(ls);
}
int Pre(int x,int k){//x的前驅 
	if(tr[k].r<x){
		if(tr[k].w) return Findp(k);
		return 0;
	} 
	int d=(tr[k].l+tr[k].r)>>1;
	int ans=0;
	if(d<x-1&&tr[rs].w&&(ans=Pre(x,rs))) return ans;
	return Pre(x,ls);
}

6、查詢某個值x的後繼

原理和前驅一樣;

int Findn(int k){//找到這結點k區間的最左邊的數 
	if(tr[k].l==tr[k].r) return tr[k].l;
	if(tr[ls].w) return Findn(ls);
	return Findn(rs);
}
int Nex(int x,int k){//x的後驅 
	if(tr[k].l>x){
		if(tr[k].w) return Findn(k);
		return 0;
	}
	int d=(tr[k].l+tr[k].r)>>1;
	int ans=0;
	if(x<d&&tr[ls].w&&(ans=Nex(x,ls))) return ans;
	return Nex(x,rs);
}

最後貼一下這道題的總代碼:

#include<bits/stdc++.h>
#define ll long long
#define pa pair<int,int>
#define ls k<<1
#define rs k<<1|1
#define inf 0x3f3f3f3f
using namespace std;
const int N=100100;
const int M=10000000;
const ll mod=100000000;
struct Nod{
	int opt,x;
}a[N];
int b[N],c[N];
int f1[M*2+100],f2[N];
int n,ans;
struct Node{
	int l,r,w;
}tr[N<<2];
void build(int l,int r,int k){
	tr[k].l=l,tr[k].r=r;
	if(l==r) return;
	int d=(l+r)>>1;
	build(l,d,ls);
	build(d+1,r,rs);
}
void update(int x,int w,int k){//在x點加1 
	if(tr[k].l==tr[k].r){
		tr[k].w+=w;
		return;
	}
	int d=(tr[k].l+tr[k].r)>>1;
	if(x<=d) update(x,w,ls);
	else update(x,w,rs);
	tr[k].w=tr[ls].w+tr[rs].w;
} 
void Xth(int x,int k){//第x小值 
	if(tr[k].l==tr[k].r){
		ans=tr[k].l;
		return;
	}
	if(tr[ls].w>=x) Xth(x,ls);
	else Xth(x-tr[ls].w,rs); 
}
void Rank(int x,int k){//x的排名,就是求1-(x-1)的區間和 
	if(tr[k].l>=1&&tr[k].r<=x){
		ans+=tr[k].w;
		return;
	}
	int d=(tr[k].l+tr[k].r)>>1;
	if(1<=d) Rank(x,ls);
	if(x>d) Rank(x,rs);
}
int Findp(int k){//找到這結點k區間的最右邊的數 
	if(tr[k].l==tr[k].r) return tr[k].l;
	if(tr[rs].w) return Findp(rs);
	return Findp(ls);
}
int Pre(int x,int k){//x的前驅 
	if(tr[k].r<x){
		if(tr[k].w) return Findp(k);
		return 0;
	} 
	int d=(tr[k].l+tr[k].r)>>1;
	int ans=0;
	if(d<x-1&&tr[rs].w&&(ans=Pre(x,rs))) return ans;
	return Pre(x,ls);
}
int Findn(int k){//找到這結點k區間的最左邊的數 
	if(tr[k].l==tr[k].r) return tr[k].l;
	if(tr[ls].w) return Findn(ls);
	return Findn(rs);
}
int Nex(int x,int k){//x的後驅 
	if(tr[k].l>x){
		if(tr[k].w) return Findn(k);
		return 0;
	}
	int d=(tr[k].l+tr[k].r)>>1;
	int ans=0;
	if(x<d&&tr[ls].w&&(ans=Nex(x,ls))) return ans;
	return Nex(x,rs);
}
int main(){
    ios::sync_with_stdio(false);
    cin>>n;
    int cnt=0;
    for(int i=1;i<=n;i++){
    	cin>>a[i].opt>>a[i].x;
    	if(a[i].opt==1||a[i].opt==3||a[i].opt==5||a[i].opt==6){
    		b[++cnt]=a[i].x;
    		c[cnt]=b[cnt];
		}
	}
	sort(b+1,b+1+cnt);
	int tot=cnt;
	cnt=unique(b+1,b+cnt+1)-b-1;
	for(int i=1;i<=tot;i++){//離散化 
		int po=lower_bound(b+1,b+cnt+1,c[i])-b;
		f1[c[i]+M]=po;
		f2[po]=c[i];
	}
	build(1,N,1);
	for(int i=1;i<=n;i++){
		if(a[i].opt==1) update(f1[a[i].x+M],1,1);
		else if(a[i].opt==2) update(f1[a[i].x+M],-1,1);
		else if(a[i].opt==3){
			ans=0;
			Rank(f1[a[i].x+M]-1,1);
			cout<<ans+1<<endl;
		}
		else if(a[i].opt==4){
			ans=0;
			Xth(a[i].x,1);
			cout<<f2[ans]<<endl;
		}
		else if(a[i].opt==5){
			cout<<f2[Pre(f1[a[i].x+M],1)]<<endl;
		}
		else{
			cout<<f2[Nex(f1[a[i].x+M],1)]<<endl;
		}
	}
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章