題目鏈接:https://codeforces.com/problemset/problem/1398/E
題目大意
你有一個集合,初始爲空。
有兩種類型的元素,一種是普通元素,一種是強化元素,兩種元素都有一個數值。
有 \(n\) 次操作,每次操作會往集合中加入一個元素或者刪除一個元素。
每次操作後,你都需要確定集合中元素的一個排列,使得排列的價值最大。
排列價值的計算規則是:
每個元素的價值就是它對應的數字,但是強化元素能夠使它在排列中後一個位置的元素的價值翻倍(即:對於排列中某一個元素來說,如果它前一個位置的元素是一個強化元素,則這個元素的價值是它本身的數值 \(\times 2\))。
舉個例子,假設現在有三個元素:
- 第 \(1\) 個元素是一個數值爲 \(5\) 的普通元素;
- 第 \(2\) 個元素是一個數值爲 \(1\) 的強化元素;
- 第 \(3\) 個元素是一個數值爲 \(8\) 的強化元素。
則:
- 如果按照第 \(1\) 個元素,第 \(2\) 個元素,第 \(3\) 個元素排列,對應的價值爲 \(5 + 1 + 2 \cdot 8 = 22\) ;
- 如果按照第 \(1\) 個元素,第 \(3\) 個元素,第 \(2\) 個元素排列,對應的價值爲 \(5 + 8 + 2 \cdot 1 = 15\) ;
- 如果按照第 \(2\) 個元素,第 \(1\) 個元素,第 \(3\) 個元素排列,對應的價值爲 \(1 + 2 \cdot 5 + 8 = 19\) ;
- 如果按照第 \(2\) 個元素,第 \(3\) 個元素,第 \(1\) 個元素排列,對應的價值爲 \(1 + 2 \cdot 8 + 2 \cdot 5 = 27\) ;
- 如果按照第 \(3\) 個元素,第 \(1\) 個元素,第 \(2\) 個元素排列,對應的價值爲 \(8 + 2 \cdot 5 + 1 = 19\) ;
- 如果按照第 \(3\) 個元素,第 \(2\) 個元素,第 \(1\) 個元素排列,對應的價值爲 \(8 + 2 \cdot 1 + 2 \cdot 5 = 20\) 。
解題思路
假設任意次操作之後,集合中有 \(m\) 個普通元素,和 \(k\) 個強化元素。則:
- 若數值最大的 \(k\) 個元素 不全是 強化元素,則答案爲:所有元素的數值和 + 數值最大的 \(k\) 個元素的數值和
- 若數值最大的 \(k\) 個元素 全是 強化元素,則答案爲:所有元素的數值和 + 數值最大的 \(k-1\) 個元素(當然它們都是強化元素)的數值和 + 數值最大的 \(1\) 個非強化元素的數值
示例程序
實現時,用:
- \(st[0]\) 保存普通元素(\(m\) 對應普通元素個數);
- \(st[1]\) 保存強化元素(\(k\) 對應強化元素個數);
- \(st2[1]\) 保存數值最大的 \(k\) 個元素;
- \(st2[0]\) 保存除了數值最大的 \(k\) 個元素以外其餘的元素;
- \(sum\) 對應 \(st2[1]\) 中數值最大的 \(k\) 個元素之和;
- \(sum2\) 對應目前所有元素之和
(命名比較隨意,主要是因爲寫的時候發現少了又加了一個;包括 \(st2[0]\) 和 \(st2[1]\) 一開始我是開了兩個堆,結果發現堆不支持隨機刪除囧)
#include <bits/stdc++.h>
using namespace std;
int n, p, d, m, k;
set<int> st[2], st2[2];
long long sum, sum2; // sum 記錄前k大值之和,sum2 記錄所有值之和
int main() {
scanf("%d", &n);
while (n--) {
scanf("%d%d", &p, &d);
if (d > 0) {
st[p].insert(d);
p ? k++ : m++;
st2[0].insert(d);
sum2 += d;
}
else { // d < 0
d = -d;
p ? k-- : m--;
st[p].erase(d);
if (st2[0].count(d)) st2[0].erase(d);
else {
st2[1].erase(d);
sum -= d;
}
sum2 -= d;
}
while (st2[0].size() && st2[1].size()) {
int x = *st2[0].rbegin();
int y = *st2[1].begin();
if (x > y) {
st2[0].erase(x);
st2[1].erase(y);
st2[0].insert(y);
st2[1].insert(x);
sum += x - y;
}
else
break;
}
while (st2[1].size() < k) {
assert(!st2[0].empty());
int x = *st2[0].rbegin();
st2[0].erase(x);
st2[1].insert(x);
sum += x;
}
while (st2[1].size() > k) {
assert(!st2[1].empty());
int x = *st2[1].begin();
st2[1].erase(x);
sum -= x;
st2[0].insert(x);
}
long long ans;
if (st[0].empty())
ans = sum2 + sum - (*st[1].begin());
else if (st[1].empty())
ans = sum2;
else if (*st[0].rbegin() < *st[1].begin())
ans = sum2 + sum - (*st[1].begin()) + (*st[0].rbegin());
else
ans = sum2 + sum;
printf("%lld\n", ans);
}
return 0;
}