BZOJ3196 二逼平衡樹 Solution

題意:寫一個數據結構,支持如下操作(括號內爲輸入中的操作編號):

(1)區間第k小,即區間內排名爲k的數(操作2)

(2)區間內求某個數的排名,即區間內小於它的數的個數加一(操作1)

(3)修改某個位置的數(操作3)

(4)區間內求某個數的前驅(小於它的最大數,操作4)、後繼(大於它的最小數,操作5)。


Sol:以下提供兩種做法。

Sol1:線段樹套平衡樹(Treap)。非常裸的做法:詢問區間第k小需要二分答案,複雜度爲$O(\log V\cdot\log^2 n)$($V$爲值域大小,通常簡記爲$O(\log^3 n)$);其餘操作的時間複雜度均爲$O(\log^2 n)$。

Code1:

#include<cstdio>
#include<cstring>
#include<cstdlib>
#define INF (100000000)
#define l(x) S[x].l
#define r(x) S[x].r
#define v(x) S[x].v
#define cnt(x) S[x].cnt
#define p(x) S[x].p
#define s(x) S[x].s
#define lson(x) Q[x].lson
#define rson(x) Q[x].rson
#define dl(x) Q[x].dl
#define dr(x) Q[x].dr
#define root(x) Q[x].root
// Capacity of num[]; positions 1..n are used, so this assumes n <= 50000
// (NOTE(review): confirm against the problem limits).
const int N = 50001;
/* Returns the smaller of two ints. */
inline int _min(int a , int b)
{
    if (b <= a)
        return b;
    return a;
}
/* Returns the larger of two ints. */
inline int _max(int a , int b)
{
    if (b >= a)
        return b;
    return a;
}
// Treap node pool shared by the treaps of every segment-tree node.
//   l, r : child indices into S (0 means "no child"; slot 0 is never allocated)
//   v    : key stored in this node
//   cnt  : multiplicity of v
//   p    : random heap priority (insert rotates the larger priority upward)
//   s    : total multiplicity in this subtree
// Nodes unlinked by remove() are not returned to the pool.
struct Node
{
    int l , r , v , cnt , p , s;
}S[2000050];
// Index of the most recently allocated node in S.
int Node_Ind;
int newnode(int x)
{
    int q = ++Node_Ind;
    v(q) = x;
    cnt(q) = 1;
    p(q) = rand();
    s(q) = 1;
    return q;
}
void maintain(int &q)
{
    s(q) = s(l(q)) + cnt(q) + s(r(q));
}
// Left rotation at q: the right child tmp becomes the subtree root and the
// old root q becomes tmp's left child.  Sizes are refreshed bottom-up (q,
// then tmp), and q is overwritten in place with the new subtree root.
// Callers (insert/remove) only rotate toward a child that exists.
void lr(int &q)
{
    int tmp = r(q);
    r(q) = l(tmp);   // tmp's left subtree moves under q
    l(tmp) = q;
    maintain(q);
    maintain(tmp);
    q = tmp;
}
// Right rotation at q: the left child tmp becomes the subtree root and the
// old root q becomes tmp's right child.  Sizes are refreshed bottom-up (q,
// then tmp), and q is overwritten in place with the new subtree root.
// Callers (insert/remove) only rotate toward a child that exists.
void rr(int &q)
{
    int tmp = l(q);
    l(q) = r(tmp);   // tmp's right subtree moves under q
    r(tmp) = q;
    maintain(q);
    maintain(tmp);
    q = tmp;
}
// Inserts one copy of x into the treap rooted at q (q is updated in place).
// Duplicate keys share a node and bump its count; after descending, a child
// whose priority exceeds its parent's is rotated up to keep heap order.
void insert(int x , int &q)
{
    if (!q)
    {
        q = newnode(x);
        maintain(q);
        return;
    }
    if (x < v(q))
    {
        insert(x , l(q));
        if (p(l(q)) > p(q))
            rr(q);
    }
    else if (x > v(q))
    {
        insert(x , r(q));
        if (p(r(q)) > p(q))
            lr(q);
    }
    else
        ++cnt(q);
    maintain(q);
}
// Removes one copy of x from the treap rooted at q (no-op if x is absent).
// If x's node holds several copies, only its count is decremented.
// Otherwise the node is rotated downward, always lifting the child with the
// larger priority, until it has at most one child; it is then spliced out.
// Spliced-out nodes are not returned to the pool S.
void remove(int x , int &q)
{
    if (!q)
        return;
    if (x == v(q))
    {
        if (cnt(q) > 1)
            --cnt(q);
        else if (!l(q) || !r(q))
            q = (!l(q)) ? r(q) : l(q);   // at most one child: replace q by it
        else if (p(l(q)) < p(r(q)))
        {
            // Right child has the higher priority: lift it; the node being
            // removed is now the left child of the new subtree root.
            lr(q);
            remove(x , l(q));
        }
        else
        {
            // Left child wins (or ties): lift it; target is now on the right.
            rr(q);
            remove(x , r(q));
        }
    }
    else if (x < v(q))
        remove(x , l(q));
    else
        remove(x , r(q));
    // q may have become 0 after splicing out a leaf.
    if (q)
        maintain(q);
}
// Returns the largest key in the treap rooted at q that is strictly less
// than x, or -(1 << 30) when no such key exists.
// The sentinel is spelled -(1 << 30) rather than -1 << 30: left-shifting a
// negative value is undefined behaviour before C++20; the value is the same.
int getprev(int x , int &q)
{
    int ans = -(1 << 30);
    int ins = q;
    while(ins)
    {
        if (v(ins) >= x)
            ins = l(ins);
        else
        {
            // Candidate; a larger one (still < x) may lie to the right.
            ans = v(ins);
            ins = r(ins);
        }
    }
    return ans;
}
// Returns the smallest key in the treap rooted at q that is strictly greater
// than x, or 1 << 30 when every key is <= x.
int getnext(int x , int &q)
{
    int best = 1 << 30;
    for (int cur = q; cur != 0; )
    {
        if (x < v(cur))
        {
            best = v(cur);   // candidate; a smaller one (still > x) may be left
            cur = l(cur);
        }
        else
            cur = r(cur);
    }
    return best;
}
// Counts the elements (with multiplicity) in the treap rooted at q that are
// strictly less than x.
int getrank(int x , int &q)
{
    int below = 0;
    for (int cur = q; cur; )
    {
        if (v(cur) < x)
        {
            // cur and its whole left subtree are < x.
            below += cnt(cur) + s(l(cur));
            cur = r(cur);
        }
        else
            cur = l(cur);
    }
    return below;
}
// Current value at each position 1..n; main() updates it after every modify.
int num[N];
// One segment-tree node: child indices, the covered position range
// [dl, dr], and the root (index into S) of the treap holding num[dl..dr].
struct Segment_Node
{
    int lson , rson , dl , dr , root;
};
// Segment tree over positions 1..n.  Every node covers [dl, dr] and owns a
// treap (root, an index into the shared pool S) holding num[dl..dr], so a
// query on [tl, tr] decomposes into O(log n) treap queries.
struct Segment_Tree
{
    Segment_Node Q[150000];
    int ind;  // last allocated index in Q; node 1 is the root
    Segment_Tree()
    {
        memset(Q , 0 , sizeof(Q));
        ind = 0;
    }
    // Creates the node covering [tl, tr], inserts num[tl..tr] into its treap,
    // then recursively builds both halves.  Returns the new node's index.
    int build(int tl , int tr)
    {
        int q = ++ind;
        dl(q) = tl;
        dr(q) = tr;
        // Plain `int` counter: the `register` specifier is deprecated in
        // C++11 and ill-formed since C++17.
        for(int i = tl ; i <= tr ; ++i)
            insert(num[i] , root(q));
        if (tl == tr)
            return q;
        int mid = (tl + tr) >> 1;
        lson(q) = build(tl , mid);
        rson(q) = build(mid + 1 , tr);
        return q;
    }
    // Number of values at positions [tl, tr] strictly less than x.
    // Assumes [tl, tr] lies inside node q's range.
    int Seg_getrank(int tl , int tr , int x , int q = 1)
    {
        if (tl <= dl(q) && dr(q) <= tr)
            return getrank(x , root(q));
        int mid = (dl(q) + dr(q)) >> 1;
        if (tl > mid)
            return Seg_getrank(tl , tr , x , rson(q));
        else if (tr <= mid)
            return Seg_getrank(tl , tr , x , lson(q));
        else
            return Seg_getrank(tl , mid , x , lson(q)) + Seg_getrank(mid + 1 , tr , x , rson(q));
    }
    // k-th smallest value at positions [tl, tr], found by binary searching
    // the answer over [0, INF]: the result is the largest v for which fewer
    // than k values are < v.  NOTE(review): assumes every value lies in
    // [0, INF] and 1 <= k <= tr - tl + 1 — confirm against the input limits.
    int Seg_getkth(int tl , int tr , int k)
    {
        int L , R , mid;
        L = 0 , R = INF , mid = (L + R + 1 )>> 1;
        while(L < R)
        {
            if (Seg_getrank(tl , tr , mid) < k)
                L = mid;
            else
                R = mid - 1;
            mid = (L + R + 1) >> 1;
        }
        return mid;
    }
    // Replaces the value at position ins by val in every treap on the
    // root-to-leaf path.  The old value is read from num[ins], so the caller
    // must update num[ins] only after this returns (main() does).
    void modify(int ins , int val , int q = 1)
    {
        remove(num[ins] , root(q));
        insert(val , root(q));
        if (dl(q) == dr(q))
            return;
        int mid = (dl(q) + dr(q)) >> 1;
        modify(ins , val , (ins <= mid) ? lson(q) : rson(q));
    }
    // Largest value < x at positions [tl, tr]; when none exists, the
    // sentinel returned by getprev() propagates out through _max.
    int Seg_getprev(int tl , int tr , int x , int q = 1)
    {
        if (tl <= dl(q) && dr(q) <= tr)
            return getprev(x , root(q));
        int mid = (dl(q) + dr(q)) >> 1;
        if (tl > mid)
            return Seg_getprev(tl , tr , x , rson(q));
        else if (tr <= mid)
            return Seg_getprev(tl , tr , x , lson(q));
        else
            return _max(Seg_getprev(tl , mid , x , lson(q)) , Seg_getprev(mid + 1 , tr ,x , rson(q)));
    }
    // Smallest value > x at positions [tl, tr]; when none exists, the
    // sentinel returned by getnext() propagates out through _min.
    int Seg_getnext(int tl , int tr , int x , int q = 1)
    {
        if (tl <= dl(q) && dr(q) <= tr)
            return getnext(x , root(q));
        int mid = (dl(q) + dr(q)) >> 1;
        if (tl > mid)
            return Seg_getnext(tl , tr , x , rson(q));
        else if (tr <= mid)
            return Seg_getnext(tl , tr , x , lson(q));
        else
            return _min(Seg_getnext(tl ,mid , x , lson(q)) , Seg_getnext(mid + 1 , tr , x , rson(q)));
    }
}Ans;
// Reads n, m, the initial sequence and m operations, printing each answer:
//   1 l r x : rank of x in [l, r] (1 + number of values < x)
//   2 l r k : k-th smallest value in [l, r]
//   3 pos x : set position pos to x
//   4 l r x : predecessor of x in [l, r]
//   5 l r x : successor of x in [l, r]
int main()
{
    int n , m;
    // Without this check n and m would be read uninitialised on bad input.
    if (scanf("%d%d" , &n , &m) != 2)
        return 0;
    // Plain `int` counters: `register` is ill-formed since C++17.
    for(int i = 1 ; i <= n ; ++i)
        scanf("%d" , &num[i]);
    Ans.build(1 , n);
    int sign , a , b , x;
    for(int i = 1 ; i <= m;  ++i)
    {
        // Stop on truncated input instead of re-dispatching a stale op code.
        if (scanf("%d" , &sign) != 1)
            break;
        switch (sign)
        {
            case 1:
            {
                scanf("%d%d%d" , &a , &b , &x);
                printf("%d\n" , Ans.Seg_getrank(a , b , x) + 1);
                break;
            }
            case 2:
            {
                scanf("%d%d%d" , &a , &b , &x);
                printf("%d\n" , Ans.Seg_getkth(a , b , x));
                break;
            }
            case 3:
            {
                scanf("%d%d" , &a , &x);
                // modify() reads the old value from num[a], so update it after.
                Ans.modify(a , x);
                num[a] = x;
                break;
            }
            case 4:
            {
                scanf("%d%d%d" , &a , &b , &x);
                printf("%d\n" , Ans.Seg_getprev(a , b , x));
                break;
            }
            case 5:
            {
                scanf("%d%d%d" , &a , &b , &x);
                printf("%d\n" , Ans.Seg_getnext(a , b , x));
                break;
            }
            default:
                break;  // unknown op code: ignored
        }
    }
    return 0;
}

Sol2:樹狀數組套主席樹。

回憶樹狀數組:節點i存儲的是以位置i爲結尾,長度爲lowbit(i)的連續一段的信息。樹狀數組依賴於信息滿足區間減法。

求前綴和:從位置i開始,令i不斷減去lowbit(i),依次找到上一段,並把各段的信息相加即可。

修改某位置:令i不斷加上lowbit(i)得到下一段應該被修改的是哪一段,畫個圖就容易理解了。


回憶主席樹:在值域很小,或者允許離線(可先離散化)的情況下,建立可持久化權值線段樹。版本i可認爲是前i個位置的數所構成的權值線段樹。

新版本的權值線段樹只需在上一個版本的基礎上新建一條路徑上的O(logn)個節點即可。

那麼求區間$[l,r]$第k小,只需用版本$r$減去版本$l-1$得到差分後的權值線段樹,根據左子樹的size決定向左或向右走即可,避免了二分答案。

如果支持修改,那麼樸素的主席樹無法解決問題了。

我們重新定義:類似樹狀數組,第$i$個版本表示第$i-\mathrm{lowbit}(i)+1$到第$i$個位置的數所構成的權值線段樹。

那麼我們詢問區間$[l,r]$第k小的時候,同時在樹狀數組上取出前綴$r$與前綴$l-1$各自對應的$O(\log n)$棵樹,以兩者左子樹size之和的差值作爲依據,決定走向左子樹或右子樹。複雜度$O(\log^2 n)$。

修改時,需要更新樹狀數組上包含該位置的$O(\log n)$棵權值線段樹,每棵更新一條路徑,複雜度$O(\log^2 n)$。(Code2中是原地修改,只在路徑上缺少節點時才新建節點。)

Code2:

#include <cstdio>
#include <cstring>
#include <cctype>
#include <iostream>
#include <algorithm>
using namespace std;
 
// Buffered single-byte reader over stdin, refilled in 1 MiB chunks via fread.
// Returns the next byte as a value in 0..255, or EOF once input is exhausted.
// The byte is widened through unsigned char: returning a plain (possibly
// signed) char would make bytes >= 0x80 negative, so 0xFF would compare
// equal to EOF and callers passing the value to isdigit() would hit
// undefined behaviour.
inline int getc() {
    static const int L = 1 << 20;
    static char buf[L], *S = buf, *T = buf;
    if (S == T) {
        T = (S = buf) + fread(buf, 1, L, stdin);
        if (S == T)
            return EOF;
    }
    return (unsigned char)*S++;
}
// Reads the next decimal integer (optionally preceded by '-') from the
// buffered reader, skipping any other leading characters.
// Returns 0 if input ends before a number starts; the previous skip loop
// never terminated at EOF because EOF is neither a digit nor '-'.
inline int getint() {
    int c;
    do {
        c = getc();
        if (c == EOF)
            return 0;
    } while (!isdigit(c) && c != '-');
    bool sign = c == '-';
    int tmp = sign ? 0 : c - '0';
    while (isdigit(c = getc()))
        tmp = tmp * 10 + (c - '0');
    return sign ? -tmp : tmp;
}
 
// Sorts the 1-indexed range S[1..n] in place.
#define SORT(S, n) sort(S + 1, S + n + 1)
 
// Capacities: N for sequence positions, M for operations.
#define N 50010
#define M 50010
// w[i]   : value at position i (replaced by its discretised index).
// ope[i] : operation i as {type, pos-or-l, r, value-or-k}.
int w[N], ope[M][4];
 
// Global_weigh[1..top] : every value seen in the input, duplicates included.
// rank_weigh[1..id]    : the same values sorted and de-duplicated.
int Global_weigh[N + M], rank_weigh[N + M], top, id;
// Maps a raw value x to its index in the sorted, de-duplicated table
// rank_weigh[1..id]: the first index whose entry is >= x (id if x exceeds
// every entry).
int getins(int x) {
    int lo = 1, hi = id;
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (rank_weigh[mid] >= x)
            hi = mid;
        else
            lo = mid + 1;
    }
    return lo;
}
 
#define l(x) S[x].l
#define r(x) S[x].r
#define size(x) S[x].size
// Node of the dynamically allocated weight (value-domain) segment trees.
//   l, r : child indices into S (0 = absent; slot 0 stays all-zero)
//   size : number of elements whose discretised value lies in this range
// size used to be `short`, which overflows once a node counts more than
// 32767 elements (the tree for BIT index 32768 alone covers 32768 positions
// when n >= 32768).  With the two int members the struct is padded to the
// same size on typical ABIs, so widening it costs no memory.
struct Node {
    int l, r;
    int size;
}S[8000000];
// Index of the most recently allocated node in S.
int ind;
// Adds `add` to the count of discretised value `ins` in the weight tree
// rooted at q, whose node covers the value range [dl, dr].
// On the way down, each node on the path is first overwritten with a copy of
// the matching node of the tree rooted at Last (a fresh node is allocated
// when q is 0), and then its size is adjusted.
// NOTE(review): every caller in this file passes the same tree as Last and q
// (Newadd(root[j], root[j], ...)), so the copy is a self-copy and the tree is
// updated in place; new nodes are allocated only where the path is missing.
void Newadd(int Last, int &q, int dl, int dr, int ins, int add) {
    if (!q)
        q = ++ind;
    S[q] = S[Last];
    size(q) += add;
    if (dl == dr)
        return;
    int mid = (dl + dr) >> 1;
    // l(Last)/r(Last) are passed by value; l(q)/r(q) by reference so a newly
    // allocated child is linked into q.
    if (ins <= mid)
        Newadd(l(Last), l(q), dl, mid, ins, add);
    else
        Newadd(r(Last), r(q), mid + 1, dr, ins, add);
}
int build(int dl, int dr) {
    int q = ++ind;
    if (dl == dr)
        return q;
    int mid = (dl + dr) >> 1;
    l(q) = build(dl, mid);
    r(q) = build(mid + 1, dr);
    return q;
}
 
// root[j] : root (index into S) of the weight tree for BIT node j, which
//           holds the values at positions j - lowbit(j) + 1 .. j.
// bit[j]  : per-query cursor into that tree; setself() resets it to root[j]
//           and setson() moves it one level down.
int root[N], bit[N];
// Sums over the BIT prefix path of x: the size of each cursor node bit[k]
// itself (self == true) or of its left child (self == false).
int count(int x, bool self) {
    int total = 0;
    while (x != 0) {
        int node = bit[x];
        if (self)
            total += size(node);
        else
            total += size(l(node));
        x &= x - 1;  // drop the lowest set bit, i.e. x -= lowbit(x)
    }
    return total;
}
// Resets the cursors along the BIT prefix path of x to their tree roots.
inline void setself(int x) {
    while (x != 0) {
        bit[x] = root[x];
        x &= x - 1;  // x -= lowbit(x)
    }
}
// Moves every cursor on the BIT prefix path of x to its right child (d) or
// its left child (!d).
inline void setson(int x, bool d) {
    for (; x != 0; x &= x - 1) {
        if (d)
            bit[x] = r(bit[x]);
        else
            bit[x] = l(bit[x]);
    }
}
 
// Counts the values at positions (a, b] whose discretised value is < ins.
// Walks the value range [dl, dr] toward ins; every time the walk turns
// right, the whole left half is below ins and its count is added.  The
// bit[] cursors for a and b must be reset with setself() first; they are
// moved down as the walk proceeds.
int getless(int dl, int dr, int ins, int a, int b) {
    int below = 0;
    while (dl != dr) {
        int mid = (dl + dr) >> 1;
        if (ins > mid) {
            below += count(b, 0) - count(a, 0);
            setson(a, 1), setson(b, 1);
            dl = mid + 1;
        } else {
            setson(a, 0), setson(b, 0);
            dr = mid;
        }
    }
    if (dl < ins)
        below += count(b, 1) - count(a, 1);
    return below;
}
// Returns the discretised value of the k-th smallest element at positions
// (a, b], descending by comparing k with the size of the left half.  The
// bit[] cursors for a and b must be reset with setself() first.
int getkth(int dl, int dr, int k, int a, int b) {
    while (dl != dr) {
        int mid = (dl + dr) >> 1;
        int leftCount = count(b, 0) - count(a, 0);
        if (k <= leftCount) {
            setson(a, 0), setson(b, 0);
            dr = mid;
        } else {
            setson(a, 1), setson(b, 1);
            k -= leftCount;
            dl = mid + 1;
        }
    }
    return dl;
}
// Returns how many elements at positions (a, b] have discretised value
// exactly ins, by walking to that leaf.  The bit[] cursors for a and b must
// be reset with setself() first.
int getnum(int dl, int dr, int ins, int a, int b) {
    while (dl != dr) {
        int mid = (dl + dr) >> 1;
        bool goRight = ins > mid;
        setson(a, goRight), setson(b, goRight);
        if (goRight)
            dl = mid + 1;
        else
            dr = mid;
    }
    return count(b, 1) - count(a, 1);
}
 
int main() {
    int n = getint(), m = getint();
     
    register int i, j;
    for(i = 1; i <= n; ++i)
        Global_weigh[++top] = w[i] = getint();
     
    for(i = 1; i <= m; ++i) {
        ope[i][0] = getint();
        if (ope[i][0] == 3) {
            ope[i][1] = getint(), ope[i][3] = getint();
            Global_weigh[++top] = ope[i][3];
        }
        else {
            ope[i][1] = getint(), ope[i][2] = getint(), ope[i][3] = getint();
            if (ope[i][0] != 2)
                Global_weigh[++top] = ope[i][3];
        }
    }
     
    //Offline-Preatments
    SORT(Global_weigh, top);
    Global_weigh[0] = -1 << 30;
    for(i = 1; i <= top; ++i)
        if (Global_weigh[i] != Global_weigh[i - 1])
            rank_weigh[++id] = Global_weigh[i];
    for(i = 1; i <= n; ++i)
        w[i] = getins(w[i]);
    for(i = 1; i <= m; ++i)
        if (ope[i][0] != 2)
            ope[i][3] = getins(ope[i][3]);
     
    //Build the Init President Tree
    root[0] = build(1, id);
    for(i = 1; i <= n; ++i)
        root[i] = ++ind;
    for(i = 1; i <= n; ++i)
        for(j = i; j <= n; j += j & -j)
            Newadd(root[j], root[j], 1, id, w[i], 1);
     
    //Answer the Questions
    for(i = 1; i <= m; ++i) {
        if (ope[i][0] == 1) {//get-rank
            setself(ope[i][1] - 1);
            setself(ope[i][2]);
            printf("%d\n", getless(1, id, ope[i][3], ope[i][1] - 1, ope[i][2]) + 1);
        }
        else if (ope[i][0] == 2) {//get-kth
            setself(ope[i][1] - 1);
            setself(ope[i][2]);
            printf("%d\n", rank_weigh[getkth(1, id, ope[i][3], ope[i][1] - 1, ope[i][2])]);
        }
        else if (ope[i][0] == 3) {//modify
            for(j = ope[i][1]; j <= n; j += j & -j)
                Newadd(root[j], root[j], 1, id, w[ope[i][1]], -1);
            w[ope[i][1]] = ope[i][3];
            for(j = ope[i][1]; j <= n; j += j & -j)
                Newadd(root[j], root[j], 1, id, w[ope[i][1]], 1);
        }
        else if (ope[i][0] == 4) {//prev
            setself(ope[i][1] - 1);
            setself(ope[i][2]);
            int less = getless(1, id, ope[i][3], ope[i][1] - 1, ope[i][2]);
            setself(ope[i][1] - 1);
            setself(ope[i][2]);
            printf("%d\n", rank_weigh[getkth(1, id, less, ope[i][1] - 1, ope[i][2])]);
        }
        else {//succ
            setself(ope[i][1] - 1);
            setself(ope[i][2]);
            int less = getless(1, id, ope[i][3], ope[i][1] - 1, ope[i][2]);
            setself(ope[i][1] - 1);
            setself(ope[i][2]);
            int num = getnum(1, id, ope[i][3], ope[i][1] - 1, ope[i][2]);
            setself(ope[i][1] - 1);
            setself(ope[i][2]);
            printf("%d\n", rank_weigh[getkth(1, id, less + num + 1, ope[i][1] - 1, ope[i][2])]);
        }
    }
     
    return 0;
}




發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章