[學習筆記]dp凸優化/wqs二分&[八省聯考2018]林克卡特樹lct

廢話

很早就想學wqs二分,結果拖了好久。因爲以前是看了幾遍都沒有懂。。(太菜了
後來因爲計劃裏凸優化的題(比如CF321E Ciel and Gondolas,CF739E Gosha is hunting…)太多了。。又不會高深的數據結構,所以只能硬着頭皮學
網上blog這麼多,還是wqs神仙本人的論文最好懂。。網上應該會有,這邊就不放資源了

然後靈光一現就突然懂了

先上題

這道算是經典題了叭
傳送門(小心點提交可能會被封號qaq

題意就不解釋了
問題可以轉換成求樹上不相交的k+1k+1條鏈的邊權和的最大值

顯然是樹形dp啊
dp[opt][x][i]dp[opt][x][i],其中optopt表示xx節點的度數(鏈上的),所以只有可能等於0/1/20/1/2,xx表示當前節點,ii表示以xx爲根節點的子樹中有ii條鏈

然後就會有弄出一些轉移方程

這裏就不寫了(自己推一下蠻簡單的)

但是這樣的時間複雜度是O(nk)\mathcal{O(nk)}的,已經可以過掉60分了
然後大佬說這個函數是凸的
然後窩打了一張表發現是真的

那麼下面就是今天重點了
p
然後我們可以發現可以用O(n)\mathcal{O(n)}的時間求出頂點的座標
只要轉移的時候記錄一下現在的k是什麼就好了

然後我們可以新定義一個函數f(x)=dp(x)c×xf(x)=dp(x)-c\times x
很容易發現這讓函數的頂點的橫座標往左移或者往右移了
發現了什麼?這個東西就可以二分了
沒錯這個就是wqs二分了

於最後就一定能求得一個頂點是(k,f(k))(k,f(k))
所以dp(x)=f(x)+kxdp(x)=f(x)+k*x這個就是最終的答案了
時間複雜度O(nlogk)\mathcal{O(n\log k)}
是不是很簡單???(一臉天真

Code

#include <cstdio>
#include <algorithm>
#include <cstring>
#define N 300010

using namespace std;
typedef long long LL;

LL cnt, lst[N];

struct Node{
	LL to, nxt;
	LL w;
}e[N << 1];

struct Data {
	LL x, y;
	Data(LL X = 0, LL Y = 0) {
		x = X; y = Y;
	}
	inline bool operator < (const Data &o) const {
		return x < o.x || x == o.x && y > o.y;
	}
	inline Data operator + (const Data &o) const {
		return Data(x + o.x, y + o.y);
	}
	inline Data operator + (LL o) {
		return Data(x + o, y);
	}
}dp[3][N];

inline void add(LL u, LL v, LL w) {
	e[++cnt].to = v;
	e[cnt].nxt = lst[u];
	e[cnt].w = w;
	lst[u] = cnt;
}

inline Data nw(Data o, LL v) {
	return Data(o.x - v, o.y + 1);
}

inline void dfs(LL x, LL fa, LL val) {
	dp[2][x] = Data(-val, 1);
	for (LL i = lst[x]; i; i = e[i].nxt) {
		LL son = e[i].to;
		if (son == fa) continue;
		dfs(son, x, val);
		dp[2][x] = max(dp[2][x] + dp[0][son], nw(dp[1][x] + dp[1][son] + e[i].w, val));
		dp[1][x] = max(dp[1][x] + dp[0][son], dp[0][x] + dp[1][son] + e[i].w);
		dp[0][x] = dp[0][x] + dp[0][son];
	}
	dp[0][x] = max(dp[0][x], max(nw(dp[1][x], val), dp[2][x]));
}

int main() {
	LL n, k;
	scanf("%lld%lld", &n, &k);
	k++;
	LL r = 0;
	for (LL i = 1, x, y, z; i < n; ++i) {
		scanf("%lld%lld%lld", &x, &y, &z);
		add(x, y, z);
		add(y, x, z);
		r += abs(z);
	}
	LL l = -r;
	while (l <= r) {
		LL mid = l + r >> 1;
		memset(dp, 0, sizeof dp);
		dfs(1, 0, mid);
		if (dp[0][1].y <= k) r = mid - 1;
		else l = mid + 1;
	}
	memset(dp, 0, sizeof dp);
	dfs(1, 0, l);
	printf("%lld\n", l * k + dp[0][1].x);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章