平衡樹入門——Splay
一種自帶大常數的平衡樹,但是 LCT 要用到它,所以今天學了一下。
1 簡介
伸展樹(Splay Tree),也叫分裂樹,是一種二叉排序樹,它能在 \(O(\log n)\) 內完成插入、查找和刪除操作。它由丹尼爾·斯立特 Daniel Sleator 和 羅伯特·恩卓·塔揚 Robert Endre Tarjan 在1985年發明的。
在伸展樹上的一般操作都基於伸展操作:假設想要對一個二叉查找樹執行一系列的查找操作,爲了使整個查找時間更小,被查頻率高的那些條目就應當經常處於靠近樹根的位置。於是想到設計一個簡單方法, 在每次查找之後對樹進行重構,把被查找的條目搬移到離樹根近一些的地方。伸展樹應運而生。伸展樹是一種自調整形式的二叉查找樹,它會沿着從某個節點到樹根之間的路徑,通過一系列的旋轉把這個節點搬移到樹根去。
它的優勢在於不需要記錄用於平衡樹的冗餘信息。
2 數據結構剖析
Splay 的左旋和右旋操作和 Treap 是一樣的,不過因爲 Splay 並沒有記錄其餘信息,所以它旋轉的“理由”和 Treap 不一樣,但其實都是爲了均攤複雜度。多次旋轉只是一個使用概率的問題。
2.1 其他函數與結構體
結構體一共打包了這些東西:
struct point{
int size,cnt,val,ch[2],fa;
};
point p[N];
其中 \(size\) 指的是子樹大小,\(cnt\) 是這個權值的出現次數,\(val\) 是這個節點代表的權值,\(ch_0,ch_1\) 分別是左右節點,\(fa\) 值得是父親節點。
這些其他函數包括:
pushup
合併信息的一個函數
inline void pushup(int k){
p[k].size=p[p[k].ch[0]].size+p[p[k].ch[1]].size+p[k].cnt;
}
clear
清理節點
inline void clear(int k){
p[k].size=p[k].ch[0]=p[k].cnt=p[k].ch[1]=p[k].val=p[k].fa=0;
}
get
判斷這個節點是其父親的左兒子還是右兒子
inline int get(int k){
return k==p[p[k].fa].ch[1];
}
new_node
添加一個新節點
inline void new_node(int val){
++tot;p[tot].val=val;p[tot].size=1;p[tot].cnt=1;
}
2.2 旋轉
還是和 Treap 一樣的左旋和右旋,不過 Splay 的左旋右旋需要分兩種情況討論:
-
當父親是根節點的時候(圖 \(1,2\)),直接旋轉就可以了。
-
當父親和爺爺在一條直線上的時候,要先旋轉父親再旋轉兒子。(圖 \(2,3\) )
-
如果不在一條直線上,只旋轉兒子就可以。
至於情況 \(2\) 爲什麼要先旋轉父親,在旋轉兒子,這裏掛上勢能分析法的博客,可以證明,如果不這樣旋轉,實際上樹的深度是不能減小的,可以卡成 \(O(n^2)\) ,而這樣旋轉的複雜度均攤下來是 \(O(n\log n)\) 的。
在下面的代碼中,左旋和右旋寫到了一個函數裏。
inline void rotate(int k){
int y=p[k].fa,z=p[y].fa,which=get(k),which2=get(y);
p[y].ch[which]=p[k].ch[which^1];
if(p[y].ch[which]) p[p[y].ch[which]].fa=y;
p[k].ch[which^1]=y;
if(z) p[z].ch[which2]=k;
p[y].fa=k;p[k].fa=z;
pushup(y);pushup(k);
}
2.3 Splay 操作
這個操作指的是把某一個節點按照上面的旋轉規則旋轉到根節點。在伸展樹中的所有除了刪除的其他操作都需要 Splay 一下,原因是一個使用概率的問題,被訪問過的元素使用概率會比較高。Splay 操作之後不要忘記更新根節點!
inline void splay(int k){
for(int fa=p[k].fa;fa=p[k].fa,fa;rotate(k)){
if(p[fa].fa) rotate(get(fa)==get(k)?fa:k);
}
root=k;
}
2.4 查詢排名&查詢權值
基本思想和 Treap 大致相同,只是最後需要 Splay 一下,代碼也寫成非遞歸版了。
inline int getrank(int val){
int rank=0,k=root;
while(k)
if(val<p[k].val) k=p[k].ch[0];
else{
rank+=p[p[k].ch[0]].size;
if(val==p[k].val){
splay(k);return rank+1;
}
rank+=p[k].cnt;k=p[k].ch[1];
}
return INF;
}
inline int getval(int rank){
int k=root;
while(k)
if(rank<=p[p[k].ch[0]].size) k=p[k].ch[0];
else{
rank-=p[p[k].ch[0]].size+p[k].cnt;
if(rank<=0){
splay(k);return p[k].val;
}
k=p[k].ch[1];
}
return INF;
}
2.5 查詢前驅後繼
考慮到因爲可能這個節點在樹種就不存在,所以我們先插入這個節點,然後找前驅後繼就非常方便——因爲 Splay 操作,這個節點已經是根節點,前驅就是左子樹中最右邊的節點,後繼就是右子樹中最左邊的節點,直接找就可以。最後我們把這個節點刪除,刪除操作一會再說。插入與刪除操作體現在主函數中,這裏只掛上插入後查詢的代碼。最後不要忘記 Splay 操作。
inline int getpre(){
int k=p[root].ch[0];
if(!k) return INF;
while(p[k].ch[1]) k=p[k].ch[1];
splay(k);return k;
}
inline int getnext(){
int k=p[root].ch[1];
if(!k) return INF;
while(p[k].ch[0]) k=p[k].ch[0];
splay(k);return k;
}
2.6 刪除操作
基本思路是要把刪除的那個節點 Splay 到根節點上來,然後合併一下左右子樹。
詳細來說,首先把節點旋轉上來,然後看這個節點的 \(cnt\) ,如果不爲 \(1\) ,那麼直接減就可以,否則我們討論如下:
-
左右子樹都爲空。
我們直接銷燬掉這個節點。
-
左右子樹有一個爲空。
把那個不空的子樹旋轉上來,銷燬跟節點。
-
左右子樹都不爲空。
我們考慮合併兩顆子樹,顯然我們需要找到左子樹中最大的那個節點來充當新樹的跟。這個節點就是現在根節點的前驅,所以我們直接查詢這個根節點的前驅,這樣前驅就被 Splay 到了根上。我們可以發現,當這個前驅被旋轉到左子樹的根節點上時,由於它是左子樹最大的,所以它沒有右兒子,通過畫圖不難得出,再前驅與根節點交換後,根節點變成了前驅的右兒子且根節點沒有左兒子,所以我們就可以像刪除鏈表上的元素一樣刪除這個節點了。
inline void delete_(int k){
getrank(k);
if(p[root].cnt>1){
p[root].cnt--;pushup(root);return;
}
if(!p[root].ch[0]&&!p[root].ch[1]){
clear(root);root=0;return;
}
if(p[root].ch[0]==0||p[root].ch[1]==0){
int which=p[root].ch[0]==0?0:1,now=root;
root=p[root].ch[which^1];p[root].fa=0;clear(now);return;
}
int now=root,pre=getpre();
p[p[now].ch[1]].fa=pre;
p[pre].ch[1]=p[now].ch[1];
clear(now);pushup(root);
}
3 總代碼
#include<bits/stdc++.h>
#define dd double
#define ld long double
#define ll long long
#define uint unsigned int
#define ull unsigned long long
#define N 3000000
#define M number
using namespace std;
const int INF=0x3f3f3f3f;
template<typename T> inline void read(T &x) {
x=0; int f=1;
char c=getchar();
for(;!isdigit(c);c=getchar()) if(c == '-') f=-f;
for(;isdigit(c);c=getchar()) x=x*10+c-'0';
x*=f;
}
struct point{
int size,cnt,val,ch[2],fa;
};
point p[N];
struct Splay{
int root,tot;
inline void pushup(int k){
p[k].size=p[p[k].ch[0]].size+p[p[k].ch[1]].size+p[k].cnt;
}
inline void clear(int k){
p[k].size=p[k].ch[0]=p[k].cnt=p[k].ch[1]=p[k].val=p[k].fa=0;
}
inline int get(int k){
return k==p[p[k].fa].ch[1];
}
inline void rotate(int k){
int y=p[k].fa,z=p[y].fa,which=get(k),which2=get(y);
p[y].ch[which]=p[k].ch[which^1];
if(p[y].ch[which]) p[p[y].ch[which]].fa=y;
p[k].ch[which^1]=y;
if(z) p[z].ch[which2]=k;
p[y].fa=k;p[k].fa=z;
pushup(y);pushup(k);
}
inline void splay(int k){
for(int fa=p[k].fa;fa=p[k].fa,fa;rotate(k)){
if(p[fa].fa) rotate(get(fa)==get(k)?fa:k);
}
root=k;
}
inline void new_node(int val){
++tot;p[tot].val=val;p[tot].size=1;p[tot].cnt=1;
}
inline void insert(int val){
if(!root){
new_node(val);root=tot;
return;
}
int k=root,fa=0;
while(1){
if(p[k].val==val){
p[k].cnt++;pushup(k);pushup(fa);
splay(k);break;
}
fa=k;k=p[k].ch[p[k].val<val];
if(!k){
new_node(val);p[fa].ch[p[fa].val<val]=tot;
p[tot].fa=fa;pushup(tot);pushup(fa);splay(tot);break;
}
}
}
inline int getrank(int val){
int rank=0,k=root;
while(k)
if(val<p[k].val) k=p[k].ch[0];
else{
rank+=p[p[k].ch[0]].size;
if(val==p[k].val){
splay(k);return rank+1;
}
rank+=p[k].cnt;k=p[k].ch[1];
}
return INF;
}
inline int getval(int rank){
int k=root;
while(k)
if(rank<=p[p[k].ch[0]].size) k=p[k].ch[0];
else{
rank-=p[p[k].ch[0]].size+p[k].cnt;
if(rank<=0){
splay(k);return p[k].val;
}
k=p[k].ch[1];
}
return INF;
}
inline int getpre(){
int k=p[root].ch[0];
if(!k) return INF;
while(p[k].ch[1]) k=p[k].ch[1];
splay(k);return k;
}
inline int getnext(){
int k=p[root].ch[1];
if(!k) return INF;
while(p[k].ch[0]) k=p[k].ch[0];
splay(k);return k;
}
inline void delete_(int k){
getrank(k);
if(p[root].cnt>1){
p[root].cnt--;pushup(root);return;
}
if(!p[root].ch[0]&&!p[root].ch[1]){
clear(root);root=0;return;
}
if(p[root].ch[0]==0||p[root].ch[1]==0){
int which=p[root].ch[0]==0?0:1,now=root;
root=p[root].ch[which^1];p[root].fa=0;clear(now);return;
}
int now=root,pre=getpre();
p[p[now].ch[1]].fa=pre;
p[pre].ch[1]=p[now].ch[1];
clear(now);pushup(root);
}
};
Splay sp;
int n;
int main(){
read(n);
for(int i=1;i<=n;i++){
int op,x;read(op);read(x);
if(op==1) sp.insert(x);
else if(op==2) sp.delete_(x);
else if(op==3) printf("%d\n",sp.getrank(x));
else if(op==4) printf("%d\n",sp.getval(x));
else if(op==5) sp.insert(x),printf("%d\n",p[sp.getpre()].val),sp.delete_(x);
else if(op==6) sp.insert(x),printf("%d\n",p[sp.getnext()].val),sp.delete_(x);
}
}