【模板】普通平衡樹(splay)

題目鏈接:傳送門
注意splay和upd的位置就好了

//by sdfzchy
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int inf=(1<<30),N=1000100,mod=1e9+7;
int n,m;
inline int in()
{
    char tmp=getchar();
    int res=0,f=1;
    while((tmp<'0'||tmp>'9')&&tmp!='-')tmp=getchar();
    if(tmp=='-') f=-1,tmp=getchar();
    while(tmp>='0'&&tmp<='9')   res=(res<<1)+(res<<3)+(tmp^48),tmp=getchar();
    return res*f;
}

int siz[N],c[N][2],cnt[N],fa[N],val[N];
int root,sz;
inline void clr(int x) {siz[x]=cnt[x]=fa[x]=val[x]=c[x][1]=c[x][0]=0;}
inline int  gi (int x) {return c[fa[x]][1]==x;}
inline void upd(int x)
{
    if(!x) return;
    siz[x]=cnt[x];
    if(c[x][0]) siz[x]+=siz[c[x][0]];
    if(c[x][1]) siz[x]+=siz[c[x][1]];
}

void rot(int x)
{
    int old=fa[x],oldf=fa[old],o=gi(x);
    c[old][o]=c[x][!o];
    fa[c[x][!o]]=old;
    c[x][!o]=old;
    fa[old]=x;
    fa[x]=oldf;
    if(oldf) c[oldf][c[oldf][1]==old]=x;
    upd(old);upd(x);
}

void splay(int x)
{
    for(int ff;ff=fa[x];rot(x))
        if(fa[ff])
            rot((gi(x)==gi(ff))?ff:x);
    root=x;
}

void ins(int x)
{
    if(!root)
    {
        ++sz;
        cnt[sz]=siz[sz]=1;
        val[sz]=x;
        root=sz;
        return;
    }
    int pos=root,ff=0;
    while(1)
    {
        if(val[pos]==x)
        {
            cnt[pos]++;
            upd(pos);
            upd(ff);/////////
            splay(pos);////////
            break;
        }
        ff=pos;
        pos=c[pos][x>val[pos]];
        if(pos==0)
        {
            ++sz;
            cnt[sz]=siz[sz]=1;
            val[sz]=x;
            fa[sz]=ff;
            c[ff][x>val[ff]]=sz;
            upd(ff);/////////////
            splay(sz);
            break;
        }
    }
}

int rank(int x)
{
    int pos=root,ans=0;
    while(1)
    {
        if(x<val[pos]) pos=c[pos][0];
        else
        {
            ans+=c[pos][0]?siz[c[pos][0]]:0;
            if(x==val[pos]) {splay(pos);return ans+1;} 
            ans+=cnt[pos];
            pos=c[pos][1];
        }
    }
}

int kth(int x)
{
    int pos=root;
    while(1)
    {
        if(x<=siz[c[pos][0]]&&c[pos][0]) pos=c[pos][0];
        else
        {
            int tmp=(c[pos][0]?siz[c[pos][0]]:0)+cnt[pos];
            if(x<=tmp) return val[pos];
            x-=tmp;pos=c[pos][1];
        }
    }
}

int pre()
{
    int pos=c[root][0];
    while(c[pos][1]) pos=c[pos][1];
    return pos;
}

int suc()
{
    int pos=c[root][1];
    while(c[pos][0]) pos=c[pos][0];
    return pos;
}

void del(int x)
{
    int o=rank(x);
    if(cnt[root]>1) {cnt[root]--;upd(root);return;}/////////
    if(!c[root][0]&&!c[root][1]) {clr(root);root=0;return;}
    if(!c[root][0])
    {
        int old=root;
        root=c[root][1];
        fa[root]=0;
        clr(old);
        return;
    }
    if(!c[root][1])
    {
        int old=root;
        root=c[root][0];
        fa[root]=0;
        clr(old);
        return;
    }
    int ga=pre(),old=root;
    splay(ga);///////////
    fa[ga]=0;
    c[ga][1]=c[old][1];
    fa[c[old][1]]=ga;
    clr(old);
    upd(root);///////////
}

int main()
{
    n=in();
    for(int i=1,op,x;i<=n;i++)
    {
        op=in(),x=in();
        switch(op)
        {
            case 1: ins(x) ; break ;
            case 2: del(x) ; break ;
            case 3: printf("%d\n",rank(x)) ; break;
            case 4: printf("%d\n",kth(x))  ; break;
            case 5: ins(x);printf("%d\n",val[pre()]); del(x) ;break;
            case 6: ins(x); printf("%d\n",val[suc()]); del(x) ;break;
        }
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章