【LuoguP4719】動態DP模板-樹鏈剖分+線段樹+矩陣乘法

測試地址:動態DP
做法: 本題需要用到樹鏈剖分+線段樹+矩陣乘法維護動態DP。
動態DP這個東西以前聽過,但當時沒有看懂,現在想來覺得是卡在矩陣乘法這個地方。這裏用的不是傳統的矩陣乘法。
一般的DP我們肯定會做,序列上的線性動態DP(可以用線性遞推式遞推的DP)很容易想到用線段樹+矩陣乘法優化,但最大權值獨立集這個經典樹形DP模型要動態維護的話,有兩個和上面問題不同的地方,第一是它不是序列,第二它的遞推式有個max\max,這肯定不能用一般的矩陣乘法解決。下面我們來一一解決這些問題。
首先是把樹上的問題轉化爲序列上的問題來做,顯然可以想到樹鏈剖分。根據經典的DP方程:
f(i,0)=max{f(son,0),f(son,1)}f(i,0)=\sum \max\{f(son,0),f(son,1)\}
f(i,1)=vali+f(son,0)f(i,1)=val_i+\sum f(son,0)
而轉移到序列上之後,一個點不在同一條重鏈上的其他兒子的貢獻是一定的,我們把這些貢獻記作s(i,0/1)s(i,0/1),那麼新的轉移方程爲:
f(i,0)=max{f(son,0),f(son,1)}+s(i,0)f(i,0)=\max\{f(son,0),f(son,1)\}+s(i,0)
f(i,1)=f(son,0)+s(i,1)f(i,1)=f(son,0)+s(i,1)
而在每次修改的時候,根據樹鏈剖分的性質,最多有O(logn)O(\log n)條輕邊,所以最多O(logn)O(\log n)ss會改變,這個性質對接下來的討論有很大幫助。
於是開始討論第二個問題,如何加速轉移?此時我們需要用一種奇特的矩陣乘法,一般的矩陣乘法是這樣的:
ci,j=kai,kbk,jc_{i,j}=\sum_ka_{i,k}b_{k,j}
而這題需要用到的矩陣乘法是這樣的:
ci,j=maxk{ai,k+bk,j}c_{i,j}=\max_k\{a_{i,k}+b_{k,j}\}
就是把加法變成max\max,乘法變成加法。我們發現這樣的矩陣乘法和倍增+Floyd的那個合併方式完全一樣,它具有和矩陣乘法一樣的結合律,因此我們只要維護這樣的矩陣乘法就可以了。具體轉移矩陣的寫法,我們把上面轉移方程中max\max括號外面的s(i,0)s(i,0)移到裏面,分別加在兩項中,就很顯然是上面新型矩陣乘法的模式了。如果不希望從某個東西轉移,在那個位置填一個inf-inf即可,具體的矩陣因爲用latex寫太麻煩我就不寫了。而這樣的矩陣乘法的單位矩陣是,主對角線是00,其他位置都是inf-inf,證明顯然。又根據上面的結論,轉移矩陣每次最多有O(logn)O(\log n)個改變,因此用線段樹維護單點修改即可,這樣我們就以O(8nlog2n)O(8n\log^2n)的時間複雜度解決了這一題。
以下是本人代碼:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll inf=1000000000ll*1000000000ll;
int n,m,first[100010],tot=0;
int son[100010],fa[100010],top[100010],bot[100010],siz[100010];
int pos[100010],qpos[100010],tim=0;
ll val[100010],f[100010][2],s[100010][2];
struct edge
{
	int v,next;
}e[200010];
struct matrix
{
	ll s[2][2];
}seg[400010],Ans,E,C;

void insert(int a,int b)
{
	e[++tot].v=b;
	e[tot].next=first[a];
	first[a]=tot;
}

void dfs1(int v)
{
	f[v][0]=0,f[v][1]=val[v];
	son[v]=0;siz[v]=1;
	for(int i=first[v];i;i=e[i].next)
		if (e[i].v!=fa[v])
		{
			fa[e[i].v]=v;
			dfs1(e[i].v);
			f[v][0]+=max(f[e[i].v][0],f[e[i].v][1]);
			f[v][1]+=f[e[i].v][0];
			siz[v]+=siz[e[i].v];
			if (siz[e[i].v]>siz[son[v]])
				son[v]=e[i].v;
		}
}

void dfs2(int v,int tp)
{
	top[v]=tp;
	pos[v]=++tim,qpos[tim]=v;
	if (son[v]) dfs2(son[v],tp),bot[v]=bot[son[v]];
	else bot[v]=v;
	s[v][0]=f[v][0]-max(f[son[v]][0],f[son[v]][1]);
	s[v][1]=f[v][1]-f[son[v]][0];
	for(int i=first[v];i;i=e[i].next)
		if (e[i].v!=fa[v]&&e[i].v!=son[v])
			dfs2(e[i].v,e[i].v);
}

void Mult(matrix &S,matrix A,matrix B)
{
	for(int i=0;i<2;i++)
		for(int j=0;j<2;j++)
		{
			S.s[i][j]=-inf;
			for(int k=0;k<2;k++)
				S.s[i][j]=max(S.s[i][j],A.s[i][k]+B.s[k][j]);
		}
}

void pushup(int no)
{
	Mult(seg[no],seg[no<<1],seg[no<<1|1]);
}

void buildtree(int no,int l,int r)
{
	if (l==r)
	{
		seg[no].s[0][0]=seg[no].s[0][1]=s[qpos[l]][0];
		seg[no].s[1][0]=s[qpos[l]][1];
		seg[no].s[1][1]=-inf;
		return;
	}
	int mid=(l+r)>>1;
	buildtree(no<<1,l,mid);
	buildtree(no<<1|1,mid+1,r);
	pushup(no);
}

void modify(int no,int l,int r,int x)
{
	if (l==r)
	{
		seg[no]=C;
		return;
	}
	int mid=(l+r)>>1;
	if (x<=mid) modify(no<<1,l,mid,x);
	else modify(no<<1|1,mid+1,r,x);
	pushup(no);
}

void query(int no,int l,int r,int s,int t)
{
	if (l>=s&&r<=t)
	{
		Mult(Ans,Ans,seg[no]);
		return;
	}
	int mid=(l+r)>>1;
	if (s<=mid) query(no<<1,l,mid,s,t);
	if (t>mid) query(no<<1|1,mid+1,r,s,t);
}

void Modify(int x,ll v)
{
	ll last0,last1;
	Ans=E;
	query(1,1,n,pos[top[x]],pos[bot[x]]);
	last0=max(Ans.s[0][0],Ans.s[0][1]);
	last1=max(Ans.s[1][0],Ans.s[1][1]);
	
	s[x][1]+=v-val[x];
	C.s[1][0]=s[x][1];
	val[x]=v;
	C.s[0][0]=C.s[0][1]=s[x][0];
	C.s[1][0]=s[x][1];
	C.s[1][1]=-inf;
	
	modify(1,1,n,pos[x]);
	x=top[x];
	while(x!=1)
	{
		int y=fa[x];
		Ans=E;
		query(1,1,n,pos[x],pos[bot[x]]);
		ll ans0=max(Ans.s[0][0],Ans.s[0][1]);
		ll ans1=max(Ans.s[1][0],Ans.s[1][1]);
		s[y][0]+=max(ans0,ans1)-max(last0,last1);
		s[y][1]+=ans0-last0;
		C.s[0][0]=C.s[0][1]=s[y][0];
		C.s[1][0]=s[y][1];
		C.s[1][1]=-inf;
		
		Ans=E;
		query(1,1,n,pos[top[y]],pos[bot[y]]);
		last0=max(Ans.s[0][0],Ans.s[0][1]);
		last1=max(Ans.s[1][0],Ans.s[1][1]);
		
		modify(1,1,n,pos[y]);
		x=top[y];
	}
}

int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
		scanf("%lld",&val[i]);
	for(int i=1;i<n;i++)
	{
		int a,b;
		scanf("%d%d",&a,&b);
		insert(a,b),insert(b,a);
	}
	
	fa[1]=siz[0]=0;
	dfs1(1);
	f[0][0]=f[0][1]=0;
	dfs2(1,1);
	buildtree(1,1,n);
	
	E.s[1][0]=E.s[0][1]=-inf;
	for(int i=1;i<=m;i++)
	{
		int x;ll y;
		scanf("%d%lld",&x,&y);
		Modify(x,y);
		Ans=E;
		query(1,1,n,pos[1],pos[bot[1]]);
		ll ans0=max(Ans.s[0][0],Ans.s[0][1]);
		ll ans1=max(Ans.s[1][0],Ans.s[1][1]);
		printf("%lld\n",max(ans0,ans1));
	}
	
	return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章