我去我TM沒有去掉freopen調了一下午
思路
因爲如果第分鐘可以控制住疫情,那麼第以及之後的都是可以的,所以,就可以二分了。
然後就是函數如何寫,有一個顯而易見,就是每一個軍隊都要儘量靠近根節點,這樣纔會攔掉更多的點,所以,就要把每個軍隊向上提,如果可以到達根節點,就先放在一邊,然後再處理,然後,因爲到了根節點了,所以最優方案就是放在第二層的節點上,然後,就要看看這些時間夠不夠就可以了。
然後,因爲向上提的過程中十分耗時間,於是,就可以用倍增的思想搞成級別的。
代碼
#include<bits/stdc++.h>
#define maxn 50001
using namespace std;
int n,m;
int a[maxn];
int head[maxn],to[maxn<<1],v[maxn<<1],nex[maxn<<1],k;
int t;
int f[maxn][20];
long long dist[maxn][20];
long long sum,l,r,mid;
void add(int x,int y,int z){
to[k]=y,v[k]=z;
nex[k]=head[x];
head[x]=k++;
}
int deep[maxn];
void bfs(){
queue<int> q;
q.push(1);
deep[1]=1;
while(!q.empty()){
int x=q.front();
q.pop();
for(int pos=head[x];pos!=-1;pos=nex[pos]){
int y=to[pos];
if(deep[y])continue;
q.push(y);
deep[y]=deep[x]+1;
f[y][0]=x;
dist[y][0]=v[pos];
for(int i=1;i<=t;i++){
f[y][i]=f[f[y][i-1]][i-1];
dist[y][i]=dist[y][i-1]+dist[f[y][i-1]][i-1];
}
}
}
}
int s[maxn];
struct zj{
long long t;
int x;
bool operator < (const zj &y)const{
return t<y.t;
}
}fre[maxn];
int dfs(int x){
if(s[x])return 1;
int ff=0;
for(int pos=head[x];pos!=-1;pos=nex[pos]){
if(deep[to[pos]]<deep[x])continue;
ff=1;
if(!dfs(to[pos]))return 0;
}
return ff;
}
int flag[maxn];
bool check(long long time){
memset(s,0,sizeof(s));
memset(flag,0,sizeof(flag));
int tot=0;
for(int i=1;i<=m;i++){
int x=a[i];
long long p=0;
for(int j=t;j>=0;j--){
if(f[x][j]>1&&p+dist[x][j]<=time){
p+=dist[x][j];
x=f[x][j];
}
}
if(f[x][0]==1&&p+dist[x][0]<=time)fre[++tot]=(zj){time-p-dist[x][0],x};
else s[x]=1;
}
for(int pos=head[1];pos!=-1;pos=nex[pos]){
if(!dfs(to[pos])){
flag[to[pos]]=1;
}
}
sort(fre+1,fre+1+tot);
long long aa[maxn],bb[maxn];
aa[0]=bb[0]=0;
for(int i=1;i<=tot;i++){
if(flag[fre[i].x]==1&&fre[i].t<dist[fre[i].x][0])flag[fre[i].x]=0;
else aa[++aa[0]]=fre[i].t;
}
for(int pos=head[1];pos!=-1;pos=nex[pos]){
if(flag[to[pos]]){
bb[++bb[0]]=v[pos];
}
}
if(aa[0]<bb[0])return 0;
sort(aa+1,aa+1+aa[0]);
sort(bb+1,bb+1+bb[0]);
for(int i=1,j=1;i<=aa[0];){
if(aa[i]>=bb[j])i++,j++;
else i++;
if(j>bb[0])return 1;
}
return 0;
}
int main(){
memset(head,-1,sizeof(head));
scanf("%d",&n);
t=log2(n);
for(int i=1;i<n;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
sum+=z;
add(x,y,z);
add(y,x,z);
}
bfs();
scanf("%d",&m);
for(int i=1;i<=m;i++)scanf("%d",&a[i]);
r=sum+1;l=-1;
while(l+1<r){
mid=(l+r)>>1;
if(check(mid))r=mid;
else l=mid;
}
if(r>sum)printf("-1");
else printf("%lld",r);
return 0;
}