题目连接:http://www.spoj.com/problems/FTOUR2/
题目大意:给一棵n个节点的树,有m个节点被染成黑色,让你找一条黑色节点数不超过k个的路径,使得路径的权值和最大
数据范围:n<=200000,m<=n,k<=m
解题思路:还是树分治。第一步找到树的重心,然后将路径分为不经过重心的路径(就是在子树里面的路径),以及经过重心的路径。对于第一类路径,可以递归到子树处理;第二类路径需要对于重心节点维护一个 f[] 数组和g[] 数组,f[i] 表示在当前枚举的子树的节点之前的路径上有不超过 i 黑色节点的最长路径是多长,g[i] 表示当前枚举到的节点的路径上恰好有i 个黑色节点的最长路径长度!然后就是枚举与中心相连的每一个节点,处理 f g 数组,并更新最大值。
思路比较容易想到,大概说一下我遇到的问题吧!在参考了各个大牛的代码以及问题之后终于。。ac了
首先是TLE,spoj的速度是众所周知的。。。 一个优化就是对于当前的重心,将与其相连的点,按照该点子树中含有最多黑点的路径的黑点数排递增序,这样每次维护的时候就不是O(k)的了,而是和当前的黑色节点数有关!
好不容易不在TLE了,就是RE,最有可能的就是数组越界,但是排除了数组越界之后就是递归暴栈了。。。
这时候可以多开几个全局的数组,来减少递归时压站的变量数!
#include<stdio.h>
#include<iostream>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<queue>
#include<string>
#define ll long long
#define db double
#define PB push_back
using namespace std;
const int N = 200005;
const int M = N*2;
const int INTINF = 1000000000;
const ll LLINF = 0xffffffffffffffLL;
int head[N],to[M],cost[M],next[M];
int nedge;
void init()
{
memset(head,-1,sizeof(head));
nedge=0;
}
void add(int a,int b,int c)
{
to[nedge]=b,cost[nedge]=c,next[nedge]=head[a],head[a]=nedge++;
}
int tol[N],max_node[N];
int root,min_val; //min_val=0
bool bk[N],vis[N];
int father[N];
void get_root(int k,int &n)
{
tol[k]=1,max_node[k]=0;
for(int i=head[k];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=father[k])
{
father[to[i]]=k;
get_root(to[i],n);
max_node[k]=max(max_node[k],tol[to[i]]);
tol[k]+=tol[to[i]];
}
}
int mm=max(max_node[k],n-tol[k]);
if(mm<min_val)
{
min_val=mm;
root=k;
}
}
int dep[N];
void get_dep(int k,int sum)
{
dep[k]=sum;
for(int i=head[k];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=father[k])
{
father[to[i]]=k;
get_dep(to[i],sum+bk[to[i]]);
dep[k]=max(dep[k],dep[to[i]]);
}
}
}
ll ans,f[N],g[N];
int num[N],max_k;
bool cmp(int a,int b)
{
return dep[to[a]]<dep[to[b]];
}
void get_g(int k,int sum,ll d)
{
if(sum>max_k) return;
g[sum]=max(g[sum],d);
for(int i=head[k];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=father[k])
{
father[to[i]]=k;
get_g(to[i],sum+bk[to[i]],d+cost[i]);
}
}
}
void work(int k,int fa,int n)
{
min_val=INTINF;
father[k]=-1;
get_root(k,n);
int rt=root;
vis[rt]=true;
for(int i=head[rt];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=fa)
{
if(tol[to[i]]>tol[rt]) work(to[i],rt,n-tol[rt]);
else work(to[i],rt,tol[to[i]]);
}
}
int cnt=0;
for(int i=head[rt];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=fa)
{
father[to[i]]=rt;
get_dep(to[i],bk[to[i]]);
num[cnt++]=i;
}
}
sort(num,num+cnt,cmp);
if(cnt==0) {vis[rt]=false;return;}
int mm=min(max_k,dep[to[num[cnt-1]]]);
for(int i=0;i<=mm;i++) f[i]=-LLINF;
for(int i=0;i<cnt;i++)
{
int v=to[num[i]];
int len=dep[v];
for(int j=0;j<=len;j++) g[j]=-LLINF;
father[v]=rt;
get_g(v,bk[v],cost[num[i]]);
if(i>0)
{
int t=dep[to[num[i-1]]];
for(int j=0;j<=len;j++)
{
if(j+bk[rt]<=max_k) ans=max(ans,g[j]);
if(max_k-j-bk[rt]>=0)
ans=max(ans,g[j]+f[min(t,max_k-j-bk[rt])]);
}
}
for(int j=0;j<=len;j++)
{
f[j]=max(f[j],g[j]);
if(j>0) f[j]=max(f[j],f[j-1]);
if(j+bk[rt]<=max_k) ans=max(ans,f[j]);
}
}
vis[rt]=false;
}
int main()
{
#ifdef PKWV
freopen("in.in","r",stdin);
#endif // PKWV
int n,m;
scanf("%d%d%d",&n,&max_k,&m);
for(int i=0;i<m;i++)
{
int t;
scanf("%d",&t);
bk[t]=true;
}
init();
for(int i=1;i<n;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b,c),add(b,a,c);
}
ans=0;
work(1,-1,n);
cout<<ans<<endl;
return 0;
}