[SCOI2010] 序列操作(線段樹)

題目

描述

lxhgww最近收到了一個01序列,序列裏面包含了n個數,這些數要麼是0,要麼是1,現在對於這個序列有五種變換操作和詢問操作:
0 a b 把[a, b]區間內的所有數全變成0
1 a b 把[a, b]區間內的所有數全變成1
2 a b 把[a,b]區間內的所有數全部取反,也就是說把所有的0變成1,把所有的1變成0
3 a b 詢問[a, b]區間內總共有多少個1
4 a b 詢問[a, b]區間內最多有多少個連續的1
對於每一種詢問操作,lxhgww都需要給出回答,聰明的程序員們,你們能幫助他嗎?

輸入

輸入數據第一行包括2個數,n和m,分別表示序列的長度和操作數目
第二行包括n個數,表示序列的初始狀態
接下來m行,每行3個數,op, a, b,(0<=op<=4,0<=a<=b

輸出

對於每一個詢問操作,輸出一行,包括1個數,表示其對應的答案

輸入樣例

10 10
0 0 0 1 1 0 1 0 1 1
1 0 2
3 0 5
2 2 2
4 0 4
0 3 6
2 3 7
4 2 8
1 0 5
0 5 6
3 3 9

輸出樣例

5
2
6
5

說明

對於30%的數據,1<=n, m<=1000
對於100%的數據,1<=n, m<=100000


解題思路

沒有區間翻轉,那麼線段樹可以維護。

操作分析

  • 0 a b 把[a, b]區間內的所有數全變成0
  • 1 a b 把[a, b]區間內的所有數全變成1
    以上兩個均爲區間覆蓋操作,因此我們需要一個cov標記(初值爲-1表示無覆蓋,值爲0表示覆蓋爲0,值爲1表示覆蓋爲1)
  • 2 a b 把[a,b]區間內的所有數全部取反,也就是說把所有的0變成1,把所有的1變成0
    區間反轉操作,因此我們還需要一個rev標記,那麼我們得好好考慮雙標記問題了
  • 3 a b 詢問[a, b]區間內總共有多少個1
    區間查詢1的個數,其實也就是區間求和,維護一個sum即可
  • 4 a b 詢問[a, b]區間內最多有多少個連續的1
    區間查詢最大連續子序列問題,我們需要維護每個節點的mx1(最大連續1的個數)、lmx1(從左端點開始往右走最大連續1的個數)、rmx1(從右端點開始往左走最大連續1的個數),又因爲這道題有區間反轉操作,我們還要維護mx0、lmx0、rmx0。有點噁心……

線段樹節點信息

根據上面的分析,線段樹節點長這樣:

struct segTree{
    int l, r;                       //Basic info.
    int sum, mx[2], lmx[2], rmx[2]; //Maintained info.
    int cov, rev;                   //tags
    segTree(){
        l = r = 0;
        sum = mx[0] = mx[1] = lmx[0] = lmx[1] = rmx[0] = rmx[1] = 0;
        cov = -1, rev = 0;
    }
}tr[N<<2];

雙標記問題

相關鏈接

這道題需要兩個標記,而這兩個標記又互相有影響,因此我們需要給這兩個標記定義優先級

  • 若cov優先級更高,則每次cov操作後要將rev清空(其實此時不清空也是正確的,只不過rev標記沒有了意義),pushdown時先下放cov標記
  • 若rev優先級更高,則每次rev操作後要將cov取反(cov ^= 1),這樣才能保證操作的正確性,pushdown時先下放rev標記

事實上,這兩種方式都是可以的,只不過後一種稍顯麻煩一點(兩種代碼均在文末給出)

求最大連續1的個數

這種問題也挺常見的,平衡樹的題中也有(NOI2005維護數列·題解),所以單獨拿出來說一下。
這種問題不僅要存一個mx(最大連續1的個數),還要存 lmx(從左端點開始往右走最大連續1的個數)和 rmx(從右端點開始往左走最大連續1的個數),向上更新時分是否跨越區間維護,應該還是比較好理解。

    inline void pushup(int id){
        tr[id].sum = tr[lid].sum + tr[rid].sum;
        for(int i = 0; i <= 1; i++){
            tr[id].mx[i] = max(max(tr[lid].mx[i], tr[rid].mx[i]), tr[lid].rmx[i] + tr[rid].lmx[i]);
            if(tr[lid].lmx[i] == size(lid)) tr[id].lmx[i] = tr[lid].lmx[i] + tr[rid].lmx[i];
            else    tr[id].lmx[i] = tr[lid].lmx[i];
            if(tr[rid].rmx[i] == size(rid)) tr[id].rmx[i] = tr[rid].rmx[i] + tr[lid].rmx[i];
            else    tr[id].rmx[i] = tr[rid].rmx[i];
        }
    }

詢問處理

在解決詢問時,我們需要節點的很多信息,所以爲了降低常數,可以把詢問函數類型定義爲線段樹結構體類型,方便處理


兩份代碼寫的時間隔了幾個月,所以變量名和風格稍有不同……

Code#1

cov優先

#include<cstdio>
#include<algorithm>

#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l+tr[id].r)>>1)
#define size(id) (tr[id].r-tr[id].l+1)

using namespace std;

const int N = 100005;
int n, m, a[N], opt, ql, qr;

struct segTree{
    int l, r;                       //Basic info.
    int sum, mx[2], lmx[2], rmx[2]; //Maintained info.
    int cov, rev;                   //tags
    segTree(){
        l = r = 0;
        sum = mx[0] = mx[1] = lmx[0] = lmx[1] = rmx[0] = rmx[1] = 0;
        cov = -1, rev = 0;
    }
}tr[N<<2];

struct OPT_segTree{
    inline void pushup(int id){
        tr[id].sum = tr[lid].sum + tr[rid].sum;
        for(int i = 0; i <= 1; i++){
            tr[id].mx[i] = max(max(tr[lid].mx[i], tr[rid].mx[i]), tr[lid].rmx[i] + tr[rid].lmx[i]);
            if(tr[lid].lmx[i] == size(lid)) tr[id].lmx[i] = tr[lid].lmx[i] + tr[rid].lmx[i];
            else    tr[id].lmx[i] = tr[lid].lmx[i];
            if(tr[rid].rmx[i] == size(rid)) tr[id].rmx[i] = tr[rid].rmx[i] + tr[lid].rmx[i];
            else    tr[id].rmx[i] = tr[rid].rmx[i];
        }
    }
    inline void pushdown(int id){
        if(!id || tr[id].l == tr[id].r) return;
        if(tr[id].cov != -1){
            tr[lid].sum = tr[id].cov * size(lid);
            tr[lid].lmx[tr[id].cov] = tr[lid].rmx[tr[id].cov] = tr[lid].mx[tr[id].cov] = size(lid);
            tr[lid].lmx[!tr[id].cov] = tr[lid].rmx[!tr[id].cov] = tr[lid].mx[!tr[id].cov] = 0;
            tr[lid].cov = tr[id].cov, tr[lid].rev = 0;
            tr[rid].sum = tr[id].cov * size(rid);
            tr[rid].lmx[tr[id].cov] = tr[rid].rmx[tr[id].cov] = tr[rid].mx[tr[id].cov] = size(rid);
            tr[rid].lmx[!tr[id].cov] = tr[rid].rmx[!tr[id].cov] = tr[rid].mx[!tr[id].cov] = 0;
            tr[rid].cov = tr[id].cov, tr[rid].rev = 0;
            tr[id].cov = -1;
        }
        if(tr[id].rev){
            swap(tr[lid].mx[0], tr[lid].mx[1]);
            swap(tr[lid].lmx[0], tr[lid].lmx[1]);
            swap(tr[lid].rmx[0], tr[lid].rmx[1]);
            tr[lid].sum = size(lid) - tr[lid].sum;
            tr[lid].rev ^= 1;
            swap(tr[rid].mx[0], tr[rid].mx[1]);
            swap(tr[rid].lmx[0], tr[rid].lmx[1]);
            swap(tr[rid].rmx[0], tr[rid].rmx[1]);
            tr[rid].sum = size(rid) - tr[rid].sum;
            tr[rid].rev ^= 1;
            tr[id].rev = 0;
        }
    }
    void build(int id, int l, int r){
        tr[id].l = l, tr[id].r = r;
        if(tr[id].l == tr[id].r){
            if(a[l] == 1)   tr[id].sum = tr[id].lmx[1] = tr[id].rmx[1] = tr[id].mx[1] = 1;
            if(a[l] == 0)   tr[id].sum = 0, tr[id].lmx[0] = tr[id].rmx[0] = tr[id].mx[0] = 1;
            return;
        }
        build(lid, l, mid);
        build(rid, mid+1, r);
        pushup(id);
    }
    void cover(int id, int l, int r, int val){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r){
            tr[id].sum = val * size(id);
            tr[id].lmx[val] = tr[id].rmx[val] = tr[id].mx[val] = size(id);
            tr[id].lmx[!val] = tr[id].rmx[!val] = tr[id].mx[!val] = 0;
            tr[id].cov = val, tr[id].rev = 0;
            return;
        }
        if(r <= mid)    cover(lid, l, r, val);
        else if(l > mid)    cover(rid, l, r, val);
        else    cover(lid, l, mid, val), cover(rid, mid+1, r, val);
        pushup(id);
    }
    void reverse(int id, int l, int r){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r){
            swap(tr[id].mx[0], tr[id].mx[1]);
            swap(tr[id].lmx[0], tr[id].lmx[1]);
            swap(tr[id].rmx[0], tr[id].rmx[1]);
            tr[id].sum = size(id) - tr[id].sum;
            tr[id].rev ^= 1;
            return;
        }
        if(r <= mid)    reverse(lid, l, r);
        else if(l > mid)    reverse(rid, l, r);
        else    reverse(lid, l, mid), reverse(rid, mid+1, r);
        pushup(id);
    }
    int querySum(int id, int l, int r){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r)  return tr[id].sum;
        if(r <= mid)    return querySum(lid, l, r);
        else if(l > mid)    return querySum(rid, l, r);
        else    return querySum(lid, l, mid) + querySum(rid, mid+1, r);
    }
    segTree querySub(int id, int l, int r){
        pushdown(id);
        if(tr[id].l == l && tr[id].r == r)  return tr[id];
        if(r <= mid)    return querySub(lid, l, r);
        else if(l > mid)    return querySub(rid, l, r);
        else{
            segTree L = querySub(lid, l, mid), R = querySub(rid, mid+1, r), res;
            res.l = l, res.r = r;
            res.sum = L.sum + R.sum;
            for(int i = 0; i <= 1; i++){
                res.mx[i] = max(max(L.mx[i], R.mx[i]), L.rmx[i] + R.lmx[i]);
                if(L.lmx[i] == L.r - L.l + 1)   res.lmx[i] = L.lmx[i] + R.lmx[i];
                else    res.lmx[i] = L.lmx[i];
                if(R.rmx[i] == R.r - R.l + 1)   res.rmx[i] = R.rmx[i] + L.rmx[i];
                else    res.rmx[i] = R.rmx[i];
            }
            return res;
        }
    }
}Seg;

int main(){
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
    Seg.build(1, 1, n);
    while(m--){
        scanf("%d%d%d", &opt, &ql, &qr);
        ql++, qr++;
        if(opt == 0)    Seg.cover(1, ql, qr, 0);
        else if(opt == 1)   Seg.cover(1, ql, qr, 1);
        else if(opt == 2)   Seg.reverse(1, ql, qr);
        else if(opt == 3)   printf("%d\n", Seg.querySum(1, ql, qr));
        else if(opt == 4)   printf("%d\n", Seg.querySub(1, ql, qr).mx[1]);
    }
    return 0;
}

Code#2

rev優先

#include<cstdio>
#include<algorithm>

#define lid id<<1
#define rid id<<1|1
#define mid ((tr[id].l + tr[id].r) >> 1)
#define len(id) (tr[id].r - tr[id].l + 1)

using namespace std;

inline int read(){
    int x = 0;
    bool fl = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9'){
        if(ch == '-')   fl = 0;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = (x << 1) + (x << 3) + ch - '0';
        ch = getchar();
    }
    return fl ? x : -x;
}

const int N = 100005;
int n, q, a[N], opt, ql, qr;

struct seg_tree{
    int l, r;
    int sum, lenl[2], lenr[2], len[2];//七個變量:多少個 1,從左、從右、區間最長連續 0/1
    int rev, cov;//兩個標記:取反標記 & 賦值標記 
    void init(){
        sum = lenl[0] = lenr[0] = lenl[1] = lenr[1] = 0;
        cov = -1, rev = 0;//注意初值 
    }
}tr[N<<2];

void pushup(int id){
    tr[id].sum = tr[lid].sum + tr[rid].sum;
    if(tr[lid].lenl[0] == len(lid)) tr[id].lenl[0] = tr[lid].lenl[0] + tr[rid].lenl[0];
    else    tr[id].lenl[0] = tr[lid].lenl[0];
    if(tr[lid].lenl[1] == len(lid)) tr[id].lenl[1] = tr[lid].lenl[1] + tr[rid].lenl[1];
    else    tr[id].lenl[1] = tr[lid].lenl[1];
    if(tr[rid].lenr[0] == len(rid)) tr[id].lenr[0] = tr[lid].lenr[0] + tr[rid].lenr[0];
    else    tr[id].lenr[0] = tr[rid].lenr[0];
    if(tr[rid].lenr[1] == len(rid)) tr[id].lenr[1] = tr[lid].lenr[1] + tr[rid].lenr[1];
    else    tr[id].lenr[1] = tr[rid].lenr[1];
    tr[id].len[0] = max(tr[lid].len[0], max(tr[rid].len[0], tr[lid].lenr[0] + tr[rid].lenl[0]));
    tr[id].len[1] = max(tr[lid].len[1], max(tr[rid].len[1], tr[lid].lenr[1] + tr[rid].lenl[1]));
}

void pushdown(int id){
    if(tr[id].l == tr[id].r)    return;
    if(tr[id].rev){
        swap(tr[lid].lenl[0], tr[lid].lenl[1]), swap(tr[rid].lenl[0], tr[rid].lenl[1]);
        swap(tr[lid].lenr[0], tr[lid].lenr[1]), swap(tr[rid].lenr[0], tr[rid].lenr[1]);
        swap(tr[lid].len[0], tr[lid].len[1]), swap(tr[rid].len[0], tr[rid].len[1]);
        tr[lid].sum = len(lid) - tr[lid].sum, tr[rid].sum = len(rid) - tr[rid].sum;
        tr[lid].rev ^= tr[id].rev, tr[rid].rev ^= tr[id].rev;
        if(tr[lid].cov != -1)   tr[lid].cov ^= tr[id].rev;
        if(tr[rid].cov != -1)   tr[rid].cov ^= tr[id].rev;
        tr[id].rev = 0;
    }
    if(tr[id].cov != -1){
        tr[lid].cov = tr[rid].cov = tr[id].cov;
        tr[lid].sum = len(lid) * tr[id].cov;
        tr[rid].sum = len(rid) * tr[id].cov;
        tr[lid].len[tr[id].cov^1] = tr[lid].lenl[tr[id].cov^1] = tr[lid].lenr[tr[id].cov^1] = 0;
        tr[rid].len[tr[id].cov^1] = tr[rid].lenl[tr[id].cov^1] = tr[rid].lenr[tr[id].cov^1] = 0;
        tr[lid].len[tr[id].cov] = tr[lid].lenl[tr[id].cov] = tr[lid].lenr[tr[id].cov] = len(lid);
        tr[rid].len[tr[id].cov] = tr[rid].lenl[tr[id].cov] = tr[rid].lenr[tr[id].cov] = len(rid);
        tr[id].cov = -1;
    }
}

void build(int id, int l, int r){
    tr[id].init();
    tr[id].l = l, tr[id].r = r;
    if(tr[id].l == tr[id].r){
        tr[id].sum = a[l];
        if(a[l] == 0){
            tr[id].lenl[0] = tr[id].lenr[0] = tr[id].len[0] = 1;
            tr[id].lenl[1] = tr[id].lenr[1] = tr[id].len[1] = 0;
        }
        else if(a[l] == 1){
            tr[id].lenl[0] = tr[id].lenr[0] = tr[id].len[0] = 0;
            tr[id].lenl[1] = tr[id].lenr[1] = tr[id].len[1] = 1;
        }
        return;
    }
    build(lid, l, mid);
    build(rid, mid+1, r);
    pushup(id);
}

void modify_cover(int id, int l, int r, int val){
    pushdown(id);
    if(tr[id].l == l && tr[id].r == r){
        tr[id].cov = val;
        tr[id].rev = 0;
        tr[id].sum = len(id) * val;
        tr[id].lenl[val] = tr[id].lenr[val] = tr[id].len[val] = len(id);
        tr[id].lenl[val^1] = tr[id].lenr[val^1] = tr[id].len[val^1] = 0;
        return;
    }
    if(r <= mid)    modify_cover(lid, l, r, val);
    else if(l > mid)    modify_cover(rid, l, r, val);
    else    modify_cover(lid, l, mid, val), modify_cover(rid, mid+1, r, val);
    pushup(id);
}

void modify_rev(int id, int l, int r){
    pushdown(id);
    if(tr[id].l == l && tr[id].r == r){
        swap(tr[id].lenl[0], tr[id].lenl[1]);
        swap(tr[id].lenr[0], tr[id].lenr[1]);
        swap(tr[id].len[0], tr[id].len[1]);
        tr[id].sum = len(id) - tr[id].sum;
        tr[id].rev ^= 1;
        if(tr[id].cov != -1)    tr[id].cov ^= 1;
        return;
    }
    if(r <= mid)    modify_rev(lid, l, r);
    else if(l > mid)    modify_rev(rid, l, r);
    else    modify_rev(lid, l, mid), modify_rev(rid, mid+1, r);
    pushup(id);
}

seg_tree query(int id, int l, int r){
    pushdown(id);
    if(tr[id].l == l && tr[id].r == r)
        return tr[id];
    if(r <= mid)    return query(lid, l, r);
    else if(l > mid)    return query(rid, l, r);
    else{
        seg_tree t, t1, t2;
        t.init(), t1.init(), t2.init();
        t1 = query(lid, l, mid);
        t2 = query(rid, mid+1, r);
        t.sum = t1.sum + t2.sum;
        if(t1.lenl[0] == len(lid))  t.lenl[0] = t1.lenl[0] + t2.lenl[0];
        else    t.lenl[0] = t1.lenl[0];
        if(t1.lenl[1] == len(lid))  t.lenl[1] = t1.lenl[1] + t2.lenl[1];
        else    t.lenl[1] = t1.lenl[1];
        if(t2.lenr[0] == len(rid))  t.lenr[0] = t1.lenr[0] + t2.lenr[0];
        else    t.lenr[0] = t2.lenr[0];
        if(t2.lenr[1] == len(rid))  t.lenr[1] = t1.lenr[1] + t2.lenr[1];
        else    t.lenr[1] = t2.lenr[1];
        t.len[0] = max(t1.len[0], max(t2.len[0], t1.lenr[0] + t2.lenl[0]));
        t.len[1] = max(t1.len[1], max(t2.len[1], t1.lenr[1] + t2.lenl[1]));
        return t;
    }
}

int main(){
    n = read(), q = read();
    for(int i = 1; i <= n; i++) a[i] = read();
    build(1, 1, n);
    while(q--){
        opt = read(), ql = read(), qr = read();
        ql++, qr++;
        if(opt == 0)    modify_cover(1, ql, qr, 0);
        else if(opt == 1)   modify_cover(1, ql, qr, 1);
        else if(opt == 2)   modify_rev(1, ql, qr);
        else{
            seg_tree t = query(1, ql, qr);
            if(opt == 3)    printf("%d\n", t.sum);
            else if(opt == 4)   printf("%d\n", t.len[1]);
        }
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章