ABC306 E - Best Performances 題解 離散化+線段樹/splay tree

題目鏈接:https://atcoder.jp/contests/abc306/tasks/abc306_e

題目大意:

有一個長度爲 \(N\) 的序列 \(A = (A_1, A_2, \ldots, A_N)\),以及一個整數 \(K\)

初始時序列 \(A\) 的所有元素的數值均爲 \(0\)

\(Q\) 次操作,每次操作給你兩個整數 \(X_i\)\(Y_i\),你需要將序列 \(A\) 中第 \(X_i\) 個元素的數值修改爲 \(Y_i\)(即 \(A_{X_i} \leftarrow Y_i\)),然後輸出一個整數,這個整數的數值爲序列 \(A\) 中最小的 \(K\) 個數之和。

解題思路:

線段樹離散化之後,或者 splay tree 都是基本操作。

示例程序1(線段樹 + 離散化):

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, x[maxn], y[maxn], a[maxn];
vector<int> vec;

int lsh(int val) {
	return lower_bound(vec.begin(), vec.end(), val) - vec.begin() + 1;
}

// 線段樹
int tr_cnt[maxn<<2];
long long tr_sum[maxn<<2];
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
void push_up(int rt) {
	tr_cnt[rt] = tr_cnt[rt<<1] + tr_cnt[rt<<1|1];
	tr_sum[rt] = tr_sum[rt<<1] + tr_sum[rt<<1|1];
}
// 離散化之後的數字是p,增加了 c 個(+1 或者 -1) 
void add(int p, int c, int l, int r, int rt) {
//	printf("add [%d, %d] [%d, %d] %d\n", p, c, l, r, rt);
	if (l == r) {
		tr_cnt[rt] += c;
		tr_sum[rt] += (long long) c * vec[p-1];
		return;
	}
	int mid = (l + r) / 2;
	(p <= mid) ? add(p, c, lson) : add(p, c, rson);
	push_up(rt);  
}
// 前k個數 
long long query(int k, int l, int r, int rt) {
//	printf("query %d [%d , %d] %d\n", k, l, r, rt);
	if (l == r) {
		assert(tr_cnt[rt] >= k);
		return (long long) k * vec[l-1];
	}
	if (tr_cnt[rt] == k)
		return tr_sum[rt];
	int mid = (l + r) / 2;
	if (tr_cnt[rt<<1|1] >= k)
		return query(k, rson);
	return tr_sum[rt<<1|1] + query(k - tr_cnt[rt<<1|1], lson);
}

int main()
{
	scanf("%d%d%d", &n, &K, &Q);
	vec.push_back(0);
	for (int i = 0; i < Q; i++) {
		scanf("%d%d", x+i, y+i);
		vec.push_back(y[i]);
	}
	sort(vec.begin(), vec.end());
	vec.erase(unique(vec.begin(), vec.end()), vec.end());
	M = vec.size();
	add(1, n, 1, M, 1);
	for (int i = 0; i < Q; i++) {
		int p = x[i], q = y[i]; // a[p] = q
		add(lsh(a[p]), -1, 1, M, 1);
		a[p] = q;
		add(lsh(a[p]), 1, 1, M, 1);
		printf("%lld\n", query(K, 1, M, 1));
	}
	return 0;
}

示例程序2(splay tree):

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, a[maxn];

struct Node {
	int s[2], p, v;	// s[0]左兒子編號,s[1]右兒子編號,p父節點編號,v數值
	int sz, cnt; // 子樹大小
	long long sum; // 子樹數值之和
	
	Node() {}
	Node(int _v, int _p) {
		v = _v;
		p = _p;
		s[0] = s[1] = 0;
		sz = cnt = 1;
		sum = _v;
	}
} tr[maxn];
int root, idx;

void push_up(int x) {
	int ls = tr[x].s[0], rs = tr[x].s[1];
	tr[x].sz = tr[ls].sz + tr[rs].sz + tr[x].cnt;
	tr[x].sum = tr[ls].sum + tr[rs].sum + (long long) tr[x].cnt * tr[x].v;
}

void f_s(int p, int u, bool k) {
    tr[p].s[k] = u;
    tr[u].p = p;
}

void rot(int x) {
    int y = tr[x].p, z = tr[y].p;
    bool k = tr[y].s[1] == x;
    f_s(z, x, tr[z].s[1]==y);
    f_s(y, tr[x].s[k^1], k);
    f_s(x, y, k^1);
    push_up(y), push_up(x);
}

// 旋轉到 x 的父節點爲 k 爲止(若k爲0,則 x 將旋轉到根節點) 
void splay(int x, int k) {
    while (tr[x].p != k) {
        int y = tr[x].p, z = tr[y].p;
        if (z != k)
            (tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
        rot(x);
    }
    if (!k) root = x;
}

// 插入一個數值爲 v 的節點 
void ins(int v) {
    int u = root, p = 0;
    while (u) {
    	if (tr[u].v == v)
    		break;
        p = u, u = tr[u].s[v > tr[u].v];
    }
    if (u) {
    	tr[u].cnt++;
    	push_up(u);
	}
    else {
		tr[u = ++idx] = Node(v, p);
    	if (p) tr[p].s[v > tr[p].v] = u;
	}
    splay(u, 0);
}

// 找前驅:找數值 < v 的最大的那個數
int get_pre(int v) {
	int u = root, res = -1;
	while (u) {
		if (tr[u].v < v) res = u, u = tr[u].s[1];
		else u = tr[u].s[0];
	}
	return res;
}

// 找數值等於 v 的最前面(中序遍歷序號最小)那個點
int get_point(int v) {
	int u = root, res = -1;
	while (u) {
		if (tr[u].v >= v) res = u, u = tr[u].s[0];
		else u = tr[u].s[1];
	}
	return res;
}  

// 刪除一個數值爲 v 的節點 
void del(int v) {
	int u1 = get_pre(v);	// 找前驅 
	splay(u1, 0);
	int u2 = get_point(v); // 查找一個數值爲 v 的節點
	splay(u2, u1);
	if (tr[u2].cnt > 1) {
		tr[u2].cnt--;
		push_up(u2);
	}
	else
		f_s(u1, tr[u2].s[1], 1);
	push_up(u1);
}

long long query(int k) {
	int u = root;
	long long res = 0;
	if (tr[u].sz <= k) return tr[u].sum;
	while (u) {
		int ls = tr[u].s[0], rs = tr[u].s[1];
		if (tr[rs].sz >= k)
			u = rs;
		else {
			res += tr[rs].sum;
			k -= tr[rs].sz;
			if (k <= tr[u].cnt) {
				res += (long long) tr[u].v * k;
				break;
			}
			else {
				res += (long long) tr[u].v * tr[u].cnt;
				k -= tr[u].cnt;
			}
			u = ls;
		}		
	}
	return res;
}

int main()
{
	scanf("%d%d%d", &n, &K, &Q);
	ins(0);
	for (int i = 0; i < Q; i++) {
		int x, y;
		scanf("%d%d", &x, &y);
		if (a[x])
			del(a[x]);	// 刪除 delete 
		a[x] = y;
		if (a[x])
			ins(a[x]);	// 插入 insert
		printf("%lld\n", query(K));
	}
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章