bzoj2870: 最長道路tree(邊分治)

題目
題解
邊分治

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=200002;
struct node{
	int to,ne,w;
}e[N<<1];
struct kk{
	int v,l;
}t[2][N];
int n,nn,i,v[N],x,y,h[N],tot,sz[N],rt,sum,mx,c[2];
ll ans;
vector<int>a[N];
bool vis[N];
inline char gc(){
	static char buf[100000],*p1=buf,*p2=buf;
	return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
	int x=0,fl=1;char ch=gc();
	for(;ch<48||ch>57;ch=gc())if(ch=='-')fl=-1;
	for(;48<=ch&&ch<=57;ch=gc())x=(x<<3)+(x<<1)+(ch^48);
	return x*fl;
}
void add(int x,int y,int z){e[++tot]=(node){y,h[x],z},h[x]=tot;}
bool cmp(kk a,kk b){return a.v>b.v;}
void dfs1(int u,int fa){
	for (int i=h[u],v;i;i=e[i].ne)
		if ((v=e[i].to)!=fa) a[u].push_back(v),dfs1(v,u);
}
void rebuild(){
	tot=1,memset(h,0,(n+1)<<2);
	for (int i=1;i<=n;i++){
		int sz=a[i].size();
		if (sz<=2)
			for (int j=0;j<sz;j++) add(i,a[i][j],a[i][j]<=nn),add(a[i][j],i,a[i][j]<=nn);
		else{
			int o1=++n,o2=++n;
			v[o1]=v[o2]=v[i];
			add(i,o1,0),add(o1,i,0),add(i,o2,0),add(o2,i,0);
			for (int j=0;j<sz;j++) a[j&1?o1:o2].push_back(a[i][j]);
		}
	}
}
void getrt(int u,int fa){
	sz[u]=1;
	for (int i=h[u],v;i;i=e[i].ne)
		if ((v=e[i].to)!=fa && !vis[i>>1]){
			getrt(v,u);
			sz[u]+=sz[v];
			int tmp=max(sz[v],sum-sz[v]);
			if (tmp<mx) mx=tmp,rt=i;
		}
}
void dfs2(int o,int u,int fa,int len,int val){
	val=min(val,v[u]),t[o][c[o]++]=(kk){val,len};
	for (int i=h[u],v;i;i=e[i].ne)
		if ((v=e[i].to)!=fa && !vis[i>>1]) dfs2(o,v,u,len+e[i].w,val);
}
void solve(int u,int p){
	mx=1e9,sum=p,getrt(u,0);
	if (mx==1e9) return;
	int now=rt;
	c[0]=c[1]=0,vis[now>>1]=1;
	dfs2(0,e[now].to,0,0,1e9);
	dfs2(1,e[now^1].to,0,0,1e9);
	sort(t[0],t[0]+c[0],cmp);
	sort(t[1],t[1]+c[1],cmp);
	for (int i=0,j=0,len=-1e9;i<c[0];i++){
		for (;j<c[1] && t[1][j].v>=t[0][i].v;j++) len=max(len,t[1][j].l);
		ans=max(ans,1ll*t[0][i].v*(len+e[now].w+t[0][i].l+1));
	}
	for (int i=0,j=0,len=-1e9;i<c[1];i++){
		for (;j<c[0] && t[0][j].v>=t[1][i].v;j++) len=max(len,t[0][j].l);
		ans=max(ans,1ll*t[1][i].v*(len+e[now].w+t[1][i].l+1));
	}
	int SZ=sz[e[now].to];
	solve(e[now].to,SZ),solve(e[now^1].to,p-SZ);//因爲這裏,所以now必須是局部變量 
}
int main(){
	n=nn=rd();
	for (i=1;i<=n;i++) v[i]=rd();
	for (i=1;i<n;i++) x=rd(),y=rd(),add(x,y,1),add(y,x,1);
	dfs1(1,0);
	rebuild();
	solve(1,n);
	printf("%lld",ans);
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章