題目鏈接:https://atcoder.jp/contests/abc306/tasks/abc306_e
題目大意:
有一個長度爲 \(N\) 的序列 \(A = (A_1, A_2, \ldots, A_N)\),以及一個整數 \(K\)。
初始時序列 \(A\) 的所有元素的數值均爲 \(0\)。
有 \(Q\) 次操作,每次操作給你兩個整數 \(X_i\) 和 \(Y_i\),你需要將序列 \(A\) 中第 \(X_i\) 個元素的數值修改爲 \(Y_i\)(即 \(A_{X_i} \leftarrow Y_i\)),然後輸出一個整數,這個整數的數值爲序列 \(A\) 中最小的 \(K\) 個數之和。
解題思路:
線段樹離散化之後,或者 splay tree 都是基本操作。
示例程序1(線段樹 + 離散化):
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, x[maxn], y[maxn], a[maxn];
vector<int> vec;
int lsh(int val) {
return lower_bound(vec.begin(), vec.end(), val) - vec.begin() + 1;
}
// 線段樹
int tr_cnt[maxn<<2];
long long tr_sum[maxn<<2];
#define lson l, mid, rt<<1
#define rson mid+1, r, rt<<1|1
void push_up(int rt) {
tr_cnt[rt] = tr_cnt[rt<<1] + tr_cnt[rt<<1|1];
tr_sum[rt] = tr_sum[rt<<1] + tr_sum[rt<<1|1];
}
// 離散化之後的數字是p,增加了 c 個(+1 或者 -1)
void add(int p, int c, int l, int r, int rt) {
// printf("add [%d, %d] [%d, %d] %d\n", p, c, l, r, rt);
if (l == r) {
tr_cnt[rt] += c;
tr_sum[rt] += (long long) c * vec[p-1];
return;
}
int mid = (l + r) / 2;
(p <= mid) ? add(p, c, lson) : add(p, c, rson);
push_up(rt);
}
// 前k個數
long long query(int k, int l, int r, int rt) {
// printf("query %d [%d , %d] %d\n", k, l, r, rt);
if (l == r) {
assert(tr_cnt[rt] >= k);
return (long long) k * vec[l-1];
}
if (tr_cnt[rt] == k)
return tr_sum[rt];
int mid = (l + r) / 2;
if (tr_cnt[rt<<1|1] >= k)
return query(k, rson);
return tr_sum[rt<<1|1] + query(k - tr_cnt[rt<<1|1], lson);
}
int main()
{
scanf("%d%d%d", &n, &K, &Q);
vec.push_back(0);
for (int i = 0; i < Q; i++) {
scanf("%d%d", x+i, y+i);
vec.push_back(y[i]);
}
sort(vec.begin(), vec.end());
vec.erase(unique(vec.begin(), vec.end()), vec.end());
M = vec.size();
add(1, n, 1, M, 1);
for (int i = 0; i < Q; i++) {
int p = x[i], q = y[i]; // a[p] = q
add(lsh(a[p]), -1, 1, M, 1);
a[p] = q;
add(lsh(a[p]), 1, 1, M, 1);
printf("%lld\n", query(K, 1, M, 1));
}
return 0;
}
示例程序2(splay tree):
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5 + 5;
int n, M, K, Q, a[maxn];
struct Node {
int s[2], p, v; // s[0]左兒子編號,s[1]右兒子編號,p父節點編號,v數值
int sz, cnt; // 子樹大小
long long sum; // 子樹數值之和
Node() {}
Node(int _v, int _p) {
v = _v;
p = _p;
s[0] = s[1] = 0;
sz = cnt = 1;
sum = _v;
}
} tr[maxn];
int root, idx;
void push_up(int x) {
int ls = tr[x].s[0], rs = tr[x].s[1];
tr[x].sz = tr[ls].sz + tr[rs].sz + tr[x].cnt;
tr[x].sum = tr[ls].sum + tr[rs].sum + (long long) tr[x].cnt * tr[x].v;
}
void f_s(int p, int u, bool k) {
tr[p].s[k] = u;
tr[u].p = p;
}
void rot(int x) {
int y = tr[x].p, z = tr[y].p;
bool k = tr[y].s[1] == x;
f_s(z, x, tr[z].s[1]==y);
f_s(y, tr[x].s[k^1], k);
f_s(x, y, k^1);
push_up(y), push_up(x);
}
// 旋轉到 x 的父節點爲 k 爲止(若k爲0,則 x 將旋轉到根節點)
void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
(tr[y].s[1]==x) ^ (tr[z].s[1]==y) ? rot(x) : rot(y);
rot(x);
}
if (!k) root = x;
}
// 插入一個數值爲 v 的節點
void ins(int v) {
int u = root, p = 0;
while (u) {
if (tr[u].v == v)
break;
p = u, u = tr[u].s[v > tr[u].v];
}
if (u) {
tr[u].cnt++;
push_up(u);
}
else {
tr[u = ++idx] = Node(v, p);
if (p) tr[p].s[v > tr[p].v] = u;
}
splay(u, 0);
}
// 找前驅:找數值 < v 的最大的那個數
int get_pre(int v) {
int u = root, res = -1;
while (u) {
if (tr[u].v < v) res = u, u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
// 找數值等於 v 的最前面(中序遍歷序號最小)那個點
int get_point(int v) {
int u = root, res = -1;
while (u) {
if (tr[u].v >= v) res = u, u = tr[u].s[0];
else u = tr[u].s[1];
}
return res;
}
// 刪除一個數值爲 v 的節點
void del(int v) {
int u1 = get_pre(v); // 找前驅
splay(u1, 0);
int u2 = get_point(v); // 查找一個數值爲 v 的節點
splay(u2, u1);
if (tr[u2].cnt > 1) {
tr[u2].cnt--;
push_up(u2);
}
else
f_s(u1, tr[u2].s[1], 1);
push_up(u1);
}
long long query(int k) {
int u = root;
long long res = 0;
if (tr[u].sz <= k) return tr[u].sum;
while (u) {
int ls = tr[u].s[0], rs = tr[u].s[1];
if (tr[rs].sz >= k)
u = rs;
else {
res += tr[rs].sum;
k -= tr[rs].sz;
if (k <= tr[u].cnt) {
res += (long long) tr[u].v * k;
break;
}
else {
res += (long long) tr[u].v * tr[u].cnt;
k -= tr[u].cnt;
}
u = ls;
}
}
return res;
}
int main()
{
scanf("%d%d%d", &n, &K, &Q);
ins(0);
for (int i = 0; i < Q; i++) {
int x, y;
scanf("%d%d", &x, &y);
if (a[x])
del(a[x]); // 刪除 delete
a[x] = y;
if (a[x])
ins(a[x]); // 插入 insert
printf("%lld\n", query(K));
}
return 0;
}