【二分棧優化dp】圖解二分棧優化dp

原理

什麼是二分棧

個人理解的二分棧,應該是一種類似刷表法的算法,對於每個點i,先更新自己的答案,再彈掉所有轉移不如i的區間,最後在後面的 [i+1,n][i+1,n] 的區間中去二分查找可以更新的,以自己爲最優轉移的最靠左的點j,加入區間 [j,n][j,n] ,併入棧。
而與一般的單調隊列的優化相比,就差不多是填表法和刷表法的區別

圖解

先有一個區間
1
爲了方便,假設前面的點都是由0號點轉移的,即加入區間 [1,n][1,n]
2
然後枚舉到第1個點,從0更新,然後發現後面的點都可以由1號點更新(比原來0號點更優),就加入區間 [2,n][2,n]
3
枚舉到i=2的時候發現更新的區間不再是i=0時加入的區間了,應該是i=1時候加入的區間,就把第一個區間彈掉,更新2號點的答案,再加入二號點的對應區間 [j1,n][j_1,n] (二分查找)
4

依次向後枚舉
5
6
這個時候,發現當i=5時,更新點j_3的方案中,5號點比4號點好,就彈出4號點加入的區間(因爲滿足決策單調,所以只需要判斷左端點就知道整個區間以哪個點爲決策點更優),加入5號二分出來的區間 [j4,n][j_4,n]
7
然後按以上規則依次轉移即可

什麼時候使用二分棧

首先,因爲我們加入的區間一定是 [j,n][j,n],所以一定題目中要求的是決策點之間至少需要xx距離,而如果是至多,就不方便使用二分棧了,推薦用單調隊列優化
其次,我們的棧的更新是一個連續的區間,(即,只要左端點依據當前點最優,則區間內所有元素都如此),所以需要滿足決策的單調性,(就是不能出現下面這種情況)
8
其中, [k1,k2][k_1,k_2] 是以1號點作爲最優轉移點的區間

時間複雜度

因爲每個點只可能出棧入棧一個區間,且每次查找區間的時候使用二分,故時間複雜度爲 O(nlogn)O(n\log{n})

二分棧優劣

優點

  1. 時間複雜度很優秀,將一般不帶優化 O(n2)O(n^2) 的時間優化到O(nlogn)O(n\log{n})
  2. 思維難度較低,不需要推很多式子,只需要證明決策單調即可

缺點

  1. 容易被一些強迫寫 O(n)O(n) 算法的題卡掉

例題

[BZOJ1010][HNOI2008]玩具裝箱

題目大意

給你N個玩具,要把每個玩具都打包,打包第i個到第j個玩具的長度是
x=ji+k=ijCkx = j - i + \sum_{k = i} ^ {j} C_k 價格是 (xL)2(x - L) ^ 2

其實就是價格 cost(i,j)=(ji+k=ijCkL)2cost(i,j) = (j - i + \sum_{k = i} ^ {j} C_k - L)^2

其中,L是個常量

輸入

第一行輸入兩個整數N,L.接下來N行輸入Ci

輸出

輸出最小總價格

樣例

5 4
3
4
2
1
4

限制

1<=N<=50000,1<=L,Ci<=1071 <= N <= 50000, 1 <= L, Ci <= 10 ^ 7

分析

此題的轉移式很好得:
dp[i]=min{dp[j]+cost(j+1,i),0<=j<i,1<=i<=n}dp[i] = \min\{dp[j] + cost(j+1,i),0 <= j < i,1 <= i <= n\}
但是這樣做就是 O(n2)O(n^2) 的複雜度,考慮二分棧優化。

二分棧

因爲顯然 cost(i,k)>=0cost(i,k) >= 0 還是個二次函數,如圖:
9
紅色線一段就是用 i2i_2 這個點來轉移更優的區間
所以可以用二分棧來做

代碼

#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
#include <queue>

#define re register
#define digit(x) (x >= '0' && x <= '9')
#define gc getchar

typedef long long LL;

using namespace std;

LL read()
{
	LL x = 0, f = 1; char c = gc();
	while (!digit(c)){if (c == '-') f = -f; c = gc();}
	while (digit(c)) x = (x << 3) + (x << 1) + c - '0', c = gc();
	return x * f;
}

const int N = 50005;

int n;
LL L;
LL a[N], sum[N];
LL d[N];
struct Node
{
	int ind, l;
	Node(){}
	Node(int I, int L){ind = I, l = L;}
}q[N];

bool Pan(int j1, int j2, int i)
{
	LL x = i - j1 - 1 + sum[i] - sum[j1];
	LL y1 = d[j1] + (x - L) * (x - L);
	
	x = i - j2 - 1 + sum[i] - sum[j2];
	LL y2 = d[j2] + (x - L) * (x - L);
	
	return y2 <= y1;
}

int main()
{
	n = read(); L = read();
	for (re int i = 1; i <= n; i++)
		a[i] = read(),
		sum[i] = a[i] + sum[i - 1];
	
	int l = 1, r = 0;
	q[++r] = Node(0, 0);
	for (re int i = 1; i <= n; i++)
	{
		while (l < r && q[l + 1].l <= i) l++;
		LL x = i - q[l].ind - 1 + sum[i] - sum[q[l].ind];
		d[i] = d[q[l].ind] + (x - L) * (x - L);
		
		while (l <= r && Pan(q[r].ind, i, q[r].l)) r--; //彈掉前面的決策不如i的區間 
		int l1 = i, r1 = n + 1;
		if (l <= r) l1 = q[r].l; 
		while (l1 + 1 < r1)
		{
			int mid = (l1 + r1) >> 1;
			if(Pan(q[r].ind, i, mid))
				r1 = mid;
			else l1 = mid;
		}
		
		if (r1 == n + 1) continue;
		q[++r] = Node(i, r1);
	}
	
	printf("%lld\n", d[n]);
	
	return 0;
}

CSP-S 2019 Day2T2 劃分(88pts)

題目大意

將一段長度爲n的序列劃分成若干個區間,使得區間sum遞增,且所有區間的平方和最小

其實大家差不多都知道吧

分析

其實就是先二分一下對於當前點i能選的下一個最左邊的點的位置k,即:
10
其中, sum(i,k)>=sum(lai,i)sum(i,k) >= sum(la_i,i)

那麼區間 [k,n][k,n] 就是當前以i爲最優決策點的區間(因爲選到k和它後面的點的時候從i轉移肯定比從i之前的點轉移更優)。

代碼

#include <cstdio>
#include <algorithm>
#include <vector>
#include <queue>
#include <cstring>

using namespace std;
typedef long long LL;
typedef unsigned long long ull;

#define gc getchar
#define re register
#define digit(x) (x >= '0' && x <= '9')
#define ud unsigned
#define _i128 __int128

LL read()
{
	LL x = 0, f = 1; char c = gc();
	while(!digit(c)){if (c == '-') f = -f; c = gc();}
	while(digit(c)) x = (x << 3) + (x << 1) + c -'0', c = gc();
	return x * f;
}
const int N = 4e7 + 5;

int n;
int a[N], f[N];
struct Node
{
	int pos, l;
	Node(){}
	Node(int P, int L){pos = P, l = L;}
}q[N];
LL s[N];

_i128 sqr(_i128 x)
{
    return x * x;
}

void Print(_i128 x)
{
    if(!x) return ;
    Print(x / 10);
    printf("%d", x % 10);
}

int main()
{
	n = read();int op = read();
	if(op)
    {
        static const LL mod = 1 << 30;
        static LL b[N];

        LL x, y, z, m;

        x = read(), y = read(),
        z = read(), b[1] = read(),
        b[2] = read(), m = read();

        for (re int i = 3; i <= n; i++)
            b[i] = ( x * b[i - 1] + y * b[i - 2] + z ) % mod;

        LL lp = 0, p, l, r;
        for (re int i = 1; i <= m; i++)
        {
            p = read(), l = read(), r = read();

            for (re int j = lp + 1; j <= p; j++)
                a[j] = b[j] % (r - l + 1) + l;

            lp = p;
        }
    }
    else
        for (re int i = 1; i <= n; i++)
            a[i] = read();

    for (re int i = 1; i <= n; i++)
        s[i] = s[i - 1] + a[i];

    int h = 1, t = 0;
    q[++t] = Node(0, 1);
	for (re int i = 1; i <= n; i++)
    {
        while (h < t && q[h + 1].l <= i) h++;
        f[i] = q[h].pos;
        LL p = s[i] - s[f[i]];

        int l = i, r = n + 1;
        while(l + 1 < r)
        {
        	int mid = (l + r) >> 1;
        	if(s[mid] - s[i] >= p)
        		r = mid;
        	else l = mid;
		}
		if (r == n + 1) continue;

		while (h <= t && r <= q[t].l) t--;
		q[++t] = Node(i, r);
    }

    _i128 ans = 0; int now = n;
    while(now)
    {
        ans += sqr(s[now] - s[f[now]]);
        now = f[now];
    }

    if(!ans) putchar('0');
	else Print(ans);
    putchar('\n');

	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章