【IOI2018】會議【笛卡爾樹】【dp】【線段樹】

題意:長度爲nn的序列,qq次詢問,每次給定一個區間,欽定區間中的一個位置xx,使得區間所有點 與xx之間的最大值(含端點) 之和 最小,輸出最小值。

n,q7.5×105n,q\leq7.5\times10^5

神仙題,不愧是IOI

首先有一個O(n2)O(n^2)的 dp

f(l,r)=min{f(l,k1)+(rk+1)hk,(kl+1)hk+f(k+1,r)}f(l,r)=\min\{f(l,k-1)+(r-k+1)h_k,(k-l+1)h_k+f(k+1,r)\}

其中kkl,rl,r中的最大值的位置(如有多個隨便選一個),即討論欽定的點在最大值的左側或右側,然後另一側的點貢獻都是最大值

我覺得我考場上能想到這步就不錯了

這個 dp 轉移已經O(1)O(1)了,也不好壓成一維,所以要麼用可持久化之類的東西強行壓狀態,要麼就不記錄無用的狀態

直接想的話兩條路都不好走,但注意到這個過程實際上是最值分治,自然地想到笛卡爾樹

哪裏自然了啊kora

具體地講,建出序列的笛卡爾樹,然後在樹上做上面dp,每個點只記錄它代表的區間的dp值,這樣就可以O(n)O(n)處理出來了。

然而就算處理出來了你仍然無法快速計算答案,因爲詢問區間可能會被拆成很多小段,你需要像平衡樹一樣沿着樹遞歸下去。而笛卡爾樹高是O(n)O(n)的,仍然可以被卡成狗。

不過思路感覺很對,考慮怎麼優化

觀察一下這個dp方程式

f(l,r)=min{f(l,k1)+(rk+1)hk,f(k+1,r)+(kl+1)hk}f(l,r)=\min\{f(l,k-1)+(r-k+1)h_k,f(k+1,r)+(k-l+1)h_k\}

套到樹上:當前子樹根結點是kk,爲了計算f(l,k1)f(l,k-1)f(k+1,r)f(k+1,r),我們需要繼續往左右子樹遞歸計算,我們這樣子是不行的

但這個f(l,k1)f(l,k-1)f(k+1,r)f(k+1,r)比較特殊:它們都有一個端點是固定的!

爲了敘述方便,下面只討論f(k+1,r)f(k+1,r),左邊的f(l,k1)f(l,k-1)是同理的

我們要是知道右子樹的區間的所有前綴的dp信息就好了

看上去很扯,但實際上是可行的!

假設我們分別知道kk的左右子樹的前綴信息,也就是知道f(l,l...k1)f(l,l...k-1)f(k+1,k+1...r)f(k+1,k+1...r),現在考慮怎麼合併f(l,l...r)f(l,l...r)

顯然左邊是不用管的

對於右邊,我們再把這個方程式搬出來。爲了看着順眼,我把rr換成了ii

f(l,i)=min{f(l,k1)+(ik+1)hk,f(k+1,i)+(kl+1)hk}f(l,i)=\min\{f(l,k-1)+(i-k+1)h_k,f(k+1,i)+(k-l+1)h_k\}

注意這個ii是在[k+1,r][k+1,r]內的,冷靜分析一下,發現這個東西就是在原來的基礎上整體加一個數,然後和一個一次函數取min\min,因爲是個區間,所以可以用線段樹維護!

而這個方程是可以找到一個分界點,使得分界點左邊取左邊的值,右邊取右邊的值。原因是ii每增加11,左邊的值固定增加hkh_k,而右邊增加f(k+1,i+1)f(k+1,i)f(k+1,i+1)-f(k+1,i),區間往右擴張一位增加的代價總不可能大於區間最大值吧……所以左邊遲早會超過右邊,而且不會被反超。線段樹上二分即可。

什麼?線段樹開不下?

但仔細想想會發現不同的位置是不會衝突的,即每個點xx維護它當前處理的左端點到xx的dp值,所以開一個線段樹就可以了。

線段樹合併應該也可以

f(l,k1)f(l,k-1)的話再開一棵線段樹,左右倒過來就可以了

然後把詢問離線,每組詢問掛在區間最大值的點上面,建笛卡爾樹的時候順便處理一下就可以了。

複雜度O(nlogn)O(n\log n)

#include <iostream>
#include <cstring>
#include <cctype>
#include <cstdio>
#include <vector>
#define MAXN 750005
using namespace std;
inline int read()
{
	int ans=0;
	char c=getchar();
	while (!isdigit(c)) c=getchar();
	while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
	return ans;
}
typedef long long ll;
int n;
struct SegmentTree
{
	#define lc p<<1
	#define rc p<<1|1
	struct node
	{
		int l,r,tag;
		ll k,b,lv,rv;
	}t[MAXN<<2];
	void build(int p,int l,int r)
	{
		t[p].l=l,t[p].r=r;
		if (l==r) return;
		int mid=(l+r)>>1;
		build(lc,l,mid),build(rc,mid+1,r);
	}
	inline void pushcov(int p,ll k,ll b)
	{
		t[p].k=k,t[p].b=b;
		t[p].lv=k*t[p].l+b,t[p].rv=k*t[p].r+b;
		t[p].tag=1;
	}
	inline void pushadd(int p,ll k,ll b)
	{
		t[p].k+=k,t[p].b+=b;
		t[p].lv+=k*t[p].l+b,t[p].rv+=k*t[p].r+b;
		if (!t[p].tag) t[p].tag=2;		
	}
	inline void pushdown(int p)
	{
		if (t[p].tag)
		{
			if (t[p].tag==1) pushcov(lc,t[p].k,t[p].b),pushcov(rc,t[p].k,t[p].b);
			if (t[p].tag==2) pushadd(lc,t[p].k,t[p].b),pushadd(rc,t[p].k,t[p].b);
			t[p].tag=t[p].k=t[p].b=0;
		}
	}
	inline void update(int p){t[p].lv=t[lc].lv,t[p].rv=t[rc].rv;} 
	ll query(int p,int k)
	{
		if (t[p].l==t[p].r) return t[p].lv;
		pushdown(p);
		if (k<=t[lc].r) return query(lc,k);
		return query(rc,k);
	}
	void search(int p,int l,int r,ll k1,ll b1,ll b2)
	{
		if (l<=t[p].l&&t[p].r<=r)
		{
			if (k1*t[p].l+b1<=t[p].lv+b2&&k1*t[p].r+b1<=t[p].rv+b2) return pushcov(p,k1,b1);
			if (t[p].lv+b2<=k1*t[p].l+b1&&t[p].rv+b2<=k1*t[p].r+b1) return pushadd(p,0,b2);
		}
		if (r<t[p].l||t[p].r<l) return;
		pushdown(p);
		search(lc,l,r,k1,b1,b2),search(rc,l,r,k1,b1,b2);
		update(p);
	}
}lt,rt;
int h[MAXN];
inline int Max(const int& x,const int& y){return h[x]>h[y]? x:y;}
int st[20][MAXN],LOG[MAXN];
inline int rmq(int l,int r)
{
	int t=LOG[r-l+1];
	return Max(st[t][l],st[t][r-(1<<t)+1]);
}
int ql[MAXN],qr[MAXN];
ll ans[MAXN];
vector<int> lis[MAXN];
void solve(int l,int r)
{
	if (l>r) return;
	int k=rmq(l,r);
	solve(l,k-1),solve(k+1,r);
	for (int i=0;i<(int)lis[k].size();i++)
	{
		int L=ql[lis[k][i]],R=qr[lis[k][i]];
		ans[lis[k][i]]=h[k]*(R-L+1ll);
		if (L<k) ans[lis[k][i]]=min(ans[lis[k][i]],rt.query(1,L)+(R-k+1ll)*h[k]);
		if (k<R) ans[lis[k][i]]=min(ans[lis[k][i]],(k-L+1ll)*h[k]+lt.query(1,R));
	}
	ll tl=rt.query(1,l),tr=lt.query(1,r);
	lt.search(1,k,r,h[k],tl+h[k]*(1ll-k),(k-l+1ll)*h[k]);
	rt.search(1,l,k,-h[k],tr+h[k]*(k+1ll),(r-k+1ll)*h[k]);
}
int main()
{
	n=read();
	int q=read();
	for (int i=1;i<=n;i++) h[i]=read();
	lt.build(1,1,n);rt.build(1,1,n);
	for (int i=1;i<=n;i++) st[0][i]=i;
	for (int j=1;j<20;j++)
		for (int i=1;i+(1<<(j-1))<=n;i++)
			st[j][i]=Max(st[j-1][i],st[j-1][i+(1<<(j-1))]);
	LOG[0]=-1;
	for (int i=1;i<=n;i++) LOG[i]=LOG[i>>1]+1;
	for (int i=1;i<=q;i++) ql[i]=read()+1,qr[i]=read()+1,lis[rmq(ql[i],qr[i])].push_back(i);
	solve(1,n);
	for (int i=1;i<=q;i++) printf("%lld\n",ans[i]);
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章