題目連接:http://www.spoj.com/problems/FTOUR2/
題目大意:給一棵n個節點的樹,有m個節點被染成黑色,讓你找一條黑色節點數不超過k個的路徑,使得路徑的權值和最大
數據範圍:n<=200000,m<=n,k<=m
解題思路:還是樹分治。第一步找到樹的重心,然後將路徑分爲不經過重心的路徑(就是在子樹裏面的路徑),以及經過重心的路徑。對於第一類路徑,可以遞歸到子樹處理;第二類路徑需要對於重心節點維護一個 f[] 數組和g[] 數組,f[i] 表示在當前枚舉的子樹的節點之前的路徑上有不超過 i 黑色節點的最長路徑是多長,g[i] 表示當前枚舉到的節點的路徑上恰好有i 個黑色節點的最長路徑長度!然後就是枚舉與中心相連的每一個節點,處理 f g 數組,並更新最大值。
思路比較容易想到,大概說一下我遇到的問題吧!在參考了各個大牛的代碼以及問題之後終於。。ac了
首先是TLE,spoj的速度是衆所周知的。。。 一個優化就是對於當前的重心,將與其相連的點,按照該點子樹中含有最多黑點的路徑的黑點數排遞增序,這樣每次維護的時候就不是O(k)的了,而是和當前的黑色節點數有關!
好不容易不在TLE了,就是RE,最有可能的就是數組越界,但是排除了數組越界之後就是遞歸暴棧了。。。
這時候可以多開幾個全局的數組,來減少遞歸時壓站的變量數!
#include<stdio.h>
#include<iostream>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<queue>
#include<string>
#define ll long long
#define db double
#define PB push_back
using namespace std;
const int N = 200005;
const int M = N*2;
const int INTINF = 1000000000;
const ll LLINF = 0xffffffffffffffLL;
int head[N],to[M],cost[M],next[M];
int nedge;
void init()
{
memset(head,-1,sizeof(head));
nedge=0;
}
void add(int a,int b,int c)
{
to[nedge]=b,cost[nedge]=c,next[nedge]=head[a],head[a]=nedge++;
}
int tol[N],max_node[N];
int root,min_val; //min_val=0
bool bk[N],vis[N];
int father[N];
void get_root(int k,int &n)
{
tol[k]=1,max_node[k]=0;
for(int i=head[k];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=father[k])
{
father[to[i]]=k;
get_root(to[i],n);
max_node[k]=max(max_node[k],tol[to[i]]);
tol[k]+=tol[to[i]];
}
}
int mm=max(max_node[k],n-tol[k]);
if(mm<min_val)
{
min_val=mm;
root=k;
}
}
int dep[N];
void get_dep(int k,int sum)
{
dep[k]=sum;
for(int i=head[k];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=father[k])
{
father[to[i]]=k;
get_dep(to[i],sum+bk[to[i]]);
dep[k]=max(dep[k],dep[to[i]]);
}
}
}
ll ans,f[N],g[N];
int num[N],max_k;
bool cmp(int a,int b)
{
return dep[to[a]]<dep[to[b]];
}
void get_g(int k,int sum,ll d)
{
if(sum>max_k) return;
g[sum]=max(g[sum],d);
for(int i=head[k];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=father[k])
{
father[to[i]]=k;
get_g(to[i],sum+bk[to[i]],d+cost[i]);
}
}
}
void work(int k,int fa,int n)
{
min_val=INTINF;
father[k]=-1;
get_root(k,n);
int rt=root;
vis[rt]=true;
for(int i=head[rt];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=fa)
{
if(tol[to[i]]>tol[rt]) work(to[i],rt,n-tol[rt]);
else work(to[i],rt,tol[to[i]]);
}
}
int cnt=0;
for(int i=head[rt];i>=0;i=next[i])
{
if(!vis[to[i]]&&to[i]!=fa)
{
father[to[i]]=rt;
get_dep(to[i],bk[to[i]]);
num[cnt++]=i;
}
}
sort(num,num+cnt,cmp);
if(cnt==0) {vis[rt]=false;return;}
int mm=min(max_k,dep[to[num[cnt-1]]]);
for(int i=0;i<=mm;i++) f[i]=-LLINF;
for(int i=0;i<cnt;i++)
{
int v=to[num[i]];
int len=dep[v];
for(int j=0;j<=len;j++) g[j]=-LLINF;
father[v]=rt;
get_g(v,bk[v],cost[num[i]]);
if(i>0)
{
int t=dep[to[num[i-1]]];
for(int j=0;j<=len;j++)
{
if(j+bk[rt]<=max_k) ans=max(ans,g[j]);
if(max_k-j-bk[rt]>=0)
ans=max(ans,g[j]+f[min(t,max_k-j-bk[rt])]);
}
}
for(int j=0;j<=len;j++)
{
f[j]=max(f[j],g[j]);
if(j>0) f[j]=max(f[j],f[j-1]);
if(j+bk[rt]<=max_k) ans=max(ans,f[j]);
}
}
vis[rt]=false;
}
int main()
{
#ifdef PKWV
freopen("in.in","r",stdin);
#endif // PKWV
int n,m;
scanf("%d%d%d",&n,&max_k,&m);
for(int i=0;i<m;i++)
{
int t;
scanf("%d",&t);
bk[t]=true;
}
init();
for(int i=1;i<n;i++)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add(a,b,c),add(b,a,c);
}
ans=0;
work(1,-1,n);
cout<<ans<<endl;
return 0;
}