思路
二分答案+貪心是肯定的,難的是check()函數,參考了神犇的思路,給定時間k,對於每一支軍隊往在k時間內能到達的最高點走,這樣會得到最優解,先預處理出每支軍隊到達根節點的時間,每次二分時判斷多少軍隊能到達根節點(這些軍隊可以調動到沒有被控制的節點上),然後處理出沒有達到根節點的軍隊最高能到達的節點,將能到達的最高節點打標記;用DFS將標記上傳,也就是說如果該節點的所有兒子都被打了標記,那該節點也會被控制;建立兩個結構體,一個記錄可以調動的軍隊,另一個記錄尚未被控制的節點,按照降序排序,注意一點,如果軍隊所在的節點沒有被控制的話就讓他自己去控制好了,其他的只要剩餘時間大於到達時間貪心就可以了;
Code
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int MAXN=50000+10;
int head[MAXN],p[MAXN][20];
int d[MAXN],g[MAXN][20];
int arm[MAXN],r[MAXN];
bool vis[MAXN];
int n,u,v,w,m,num,sum;
struct Edge{
int next,to,w;
}edge[MAXN<<1];
struct Point{
int w,from;
}b[MAXN<<1],c[MAXN<<1];
void add(int from,int to,int w)
{
edge[++num].to=to;
edge[num].w=w;
edge[num].next=head[from];
head[from]=num;
}
void dfs(int u)
{
for(int i=head[u];i;i=edge[i].next)
if(!d[edge[i].to])
{
int to=edge[i].to;
d[to]=d[u]+1;
g[to][0]=edge[i].w;
p[to][0]=u;
dfs(to);
}
}
void init()
{
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i<=n;i++)
if(p[i][j-1])
{
p[i][j]=p[p[i][j-1]][j-1];
g[i][j]=g[i][j-1]+g[p[i][j-1]][j-1];
}
}
void work()
{
for(int i=1;i<=m;i++)
{
int f=d[arm[i]]-d[1],x=arm[i];
for(int j=(int)log2(n);j>=0;j--)
if((1<<j)&f) r[i]+=g[x][j],x=p[x][j];
}
}
void pro(int i,int res)
{
int time=0;
for(int j=(int)log2(n);j>=0;j--)
if(p[i][j]&&g[i][j]+time<=res)
{
time+=g[i][j];
i=p[i][j];
}
vis[i]=1;
}
void pushup(int u)
{
int p1=1,q=0;
for(int i=head[u];i;i=edge[i].next)
if(edge[i].to!=p[u][0])
{
pushup(edge[i].to);
p1=p1&vis[edge[i].to];
q=1;
}
if(p1&&q&&u!=1) vis[u]=1;
}
int cmp(Point a,Point b) {return a.w<b.w;}
int check(int time)
{
memset(vis,0,sizeof vis);
int cnt=0,top=0;
for(int i=1;i<=m;i++)
if(r[i]>time) pro(arm[i],time);
else {
int y=arm[i];
b[++cnt].w=time-r[i];
for(int j=(int)log2(n);j>=0;j--)
if(p[y][j]>1) y=p[y][j];
b[cnt].from=y;
}
pushup(1);
for(int i=head[1];i;i=edge[i].next)
if(!vis[edge[i].to])
{
c[++top].from=edge[i].to;
c[top].w=edge[i].w;
}
sort(b+1,b+cnt+1,cmp);
sort(c+1,c+top+1,cmp);
int j=1;c[top+1].w=0x7fffffff;
for(int i=1;i<=cnt;i++)
{
if(!vis[b[i].from]) vis[b[i].from]=1;
else if(b[i].w>=c[j].w) vis[c[j].from]=1;
while(vis[c[j].from]) j++;
}
if(j>top) return 1;
return 0;
}
int main()
{
freopen("01.in","r",stdin);
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&w);
sum+=w;
add(u,v,w);
add(v,u,w);
}
d[1]=1;
dfs(1);
init();
scanf("%d",&m);
for(int i=1;i<=m;i++)
scanf("%d",&arm[i]);
work();
int l=0,r=sum,ans=0;
while(l<=r)
{
int m=(l>>1)+(r>>1)+(l&r&1);
if(check(m)) ans=m,r=m-1;
else l=m+1;
}
printf("%d",ans);
return 0;
}