實際上給的單向邊而不是雙向邊,先處理出每個點到 的最短距離 , 到每個點的最短距離 。 預處理權值 ,集合劃分後的每一個點的貢獻爲:,其中 爲劃分到的集合的大小。 顯然最優情況是值連續的劃分到一組,對 進行排序,就可以在序列上做線性 dp。
考慮 表示前 個分成 組的最小總距離。轉移方程爲 dp[k][i] = dp[k - 1][j] + (i - j - 1) * (sum[i] - sum[j])
,總複雜度爲 ,考慮優化:
1、不考慮選擇分 k 組這個限制,肯定儘可能多分組最優,給每個分組加上一個代價 ,每分一次組都要加上 的代價,顯然 越大,最優解的分組數越少, 越小, 最優解的分組數越大,滿足單調性,考慮二分這個代價 ,然後做沒有限制的 dp 的複雜度是 ,二分的右邊界要大一點,大到最優解可能只分一次組。(從凸包的角度考慮可能不容易看出)
2、由於權值比較小的點劃分到的 肯定更大,決策具有單調性,利用這個單調性,當計算 dp[k][i] 時,轉移範圍只要枚舉 ,因爲最後這一塊的大小肯定小於等於平均值, 次計算後,對於每一個 ,計算所有的 的複雜度是 ,最後總複雜度爲 ,這個 沒有跑滿,因此跑得比較快。
由於決策轉移點具有單調性,還可以實現到 ,且已經有 的做法 (都不會)
wqs二分優化代碼:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e3 + 10;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
const ll inf = 1e15;
int n,b,s,r,a[maxn],vis[maxn];
ll sum[maxn],d[maxn],t[maxn];
vector<pii> g[maxn],h[maxn];
ll dp[maxn],tp[maxn],lst[maxn];
void spfa1(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
d[i] = inf;
memset(vis,0,sizeof vis);
d[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : g[top]) {
if (d[it.fir] > d[top] + it.sec) {
d[it.fir] = d[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
void spfa2(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
t[i] = inf;
memset(vis,0,sizeof vis);
t[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : h[top]) {
if (t[it.fir] > t[top] + it.sec) {
t[it.fir] = t[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
ll solve(ll x) {
for (int i = 0; i <= b; i++)
dp[i] = inf, lst[i] = tp[i] = 0;
dp[0] = 0;
for (int i = 1; i <= b; i++) {
for (int j = lst[i]; j < i; j++) {
if (dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x < dp[i]) {
dp[i] = dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x;
tp[i] = tp[j] + 1;
} else if (dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x == dp[i]) {
if (tp[i] < tp[j] + 1)
tp[i] = tp[j] + 1;
}
}
}
return tp[b];
}
int main() {
scanf("%d%d%d%d",&n,&b,&s,&r);
for (int i = 1; i <= r; i++) {
int u,v,w; scanf("%d%d%d",&u,&v,&w);
g[u].push_back(pii(v,w));
h[v].push_back(pii(u,w));
}
spfa1(b + 1); spfa2(b + 1);
for (int i = 1; i <= b; i++)
sum[i] = d[i] + t[i];
sort(sum + 1,sum + b + 1);
for (int i = 1; i <= b; i++)
sum[i] += sum[i - 1];
ll l = 0, r = 1ll << 48;
while (l < r) {
ll mid = l + r >> 1;
if (solve(mid) < s) r = mid;
else l = mid + 1;
}
solve(l - 1);
printf("%lld\n",dp[b] - s * (l - 1));
return 0;
}
決策單調性優化:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e3 + 10;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
const ll inf = 1e15;
int n,b,s,r,a[maxn],vis[maxn];
ll sum[maxn],d[maxn],t[maxn];
vector<pii> g[maxn],h[maxn];
ll dp[maxn],tp[maxn];
void spfa1(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
d[i] = inf;
memset(vis,0,sizeof vis);
d[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : g[top]) {
if (d[it.fir] > d[top] + it.sec) {
d[it.fir] = d[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
void spfa2(int s) {
queue<int> q;
for (int i = 1; i <= n; i++)
t[i] = inf;
memset(vis,0,sizeof vis);
t[s] = 0;
q.push(s);
while (!q.empty()) {
int top = q.front();
q.pop();
vis[top] = 0;
for (auto it : h[top]) {
if (t[it.fir] > t[top] + it.sec) {
t[it.fir] = t[top] + it.sec;
if (!vis[it.fir]) {
q.push(it.fir);
vis[it.fir] = 1;
}
}
}
}
}
ll solve() {
for (int i = 0; i <= b; i++)
tp[i] = dp[i] = inf;
tp[0] = 0;
for (int k = 1; k <= s; k++) {
for (int i = 1; i <= b; i++) {
for (int j = i - i / k; j <= i - 1; j++) // i / k 是平均每個塊的大小
dp[i] = min(dp[i],tp[j] + (i - j - 1) * (sum[i] - sum[j]));
}
for (int i = 0; i <= b; i++)
tp[i] = dp[i], dp[i] = inf;
}
return tp[b];
}
int main() {
scanf("%d%d%d%d",&n,&b,&s,&r);
for (int i = 1; i <= r; i++) {
int u,v,w; scanf("%d%d%d",&u,&v,&w);
g[u].push_back(pii(v,w));
h[v].push_back(pii(u,w));
}
spfa1(b + 1); spfa2(b + 1);
for (int i = 1; i <= b; i++)
sum[i] = d[i] + t[i];
sort(sum + 1,sum + b + 1);
for (int i = 1; i <= b; i++)
sum[i] += sum[i - 1];
printf("%lld\n",solve());
return 0;
}