看到題還以爲要二分,結果連二分都不用
枚舉斷邊,計算兩個子樹直徑,更新答案,沒了
Code:
#include<bits/stdc++.h>
using namespace std;
inline int read(){
int res=0,f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') f=-f;ch=getchar();}
while(isdigit(ch)) {res=(res<<1)+(res<<3)+(ch^48);ch=getchar();}
return res*f;
}
const int INF=1e9,N=5e3+5;
int vis[N<<1],nxt[N<<1],head[N],tot=1,c[N<<1];
inline void add(int x,int y,int z){vis[++tot]=y;nxt[tot]=head[x];head[x]=tot;c[tot]=z;}
int pt[N];
struct info{
int d,p;
info(){}
info(int _d,int _p):d(_d),p(_p){}
}dp[N];
int xx,yy,fa[N],lca,mxx=0;
void dfs_getd(int v){
pt[v]=1;dp[v]=info(0,v);
for(int i=head[v];i;i=nxt[i]){
int y=vis[i];
if(pt[y]) continue;
fa[y]=v;dfs_getd(y);
if(mxx<c[i]+dp[y].d+dp[v].d){
mxx=c[i]+dp[y].d+dp[v].d;
xx=dp[y].p;yy=dp[v].p;lca=v;
}
if(dp[y].d+c[i]>dp[v].d){
dp[v].d=dp[y].d+c[i];
dp[v].p=dp[y].p;
}
}
}
int mark[N<<1];
void dfs_mark(int v){
for(int i=head[v];i;i=nxt[i]) if(vis[i]==fa[v]){
mark[i]=mark[i^1]=1;
if(vis[i]==lca) return;
dfs_mark(vis[i]);
return;
}
}
int dep[N];
void dfs_getdep(int v){
pt[v]=1;
for(int i=head[v];i;i=nxt[i]){
int y=vis[i];
if(pt[y]) continue;
dep[y]=dep[v]+c[i];
dfs_getdep(y);
}
}
int mx[N];
int n;
inline void check(){
int ans=INF;
for(int i=2;i<=tot;i+=2) if(mark[i]){
int mn1=INF,mn2=INF,mx1=0,mx2=0;
mxx=0;
for(int i=1;i<=n;i++) pt[i]=mx[i]=dep[i]=0;
pt[vis[i]]=1;
dfs_getd(vis[i^1]);mx1=mxx;
for(int i=1;i<=n;i++) pt[i]=0;pt[vis[i]]=1;
dep[xx]=0;dfs_getdep(xx);
for(int j=1;j<=n;j++) mx[j]=max(mx[j],dep[j]);
for(int i=1;i<=n;i++) pt[i]=0;pt[vis[i]]=1;
dep[yy]=0;dfs_getdep(yy);
for(int j=1;j<=n;j++) mx[j]=max(mx[j],dep[j]);
for(int j=1;j<=n;j++) if(mx[j]) mn1=min(mn1,mx[j]);
mxx=0;
for(int i=1;i<=n;i++) pt[i]=mx[i]=dep[i]=0;
pt[vis[i^1]]=1;
dfs_getd(vis[i]);mx2=mxx;
for(int i=1;i<=n;i++) pt[i]=0;pt[vis[i^1]]=1;
dep[xx]=0;dfs_getdep(xx);
for(int j=1;j<=n;j++) mx[j]=max(mx[j],dep[j]);
for(int i=1;i<=n;i++) pt[i]=0;pt[vis[i^1]]=1;
dep[yy]=0;dfs_getdep(yy);
for(int j=1;j<=n;j++) mx[j]=max(mx[j],dep[j]);
for(int j=1;j<=n;j++) if(mx[j]) mn2=min(mn2,mx[j]);
ans=min(max(c[i]+mn1+mn2,max(mx1,mx2)),ans);
}
cout<<ans;
}
int main(){
n=read();
for(int x,y,z,i=1;i<n;i++){
x=read();y=read();z=read();
add(x,y,z);add(y,x,z);
}
fa[1]=0;mxx=0;
dfs_getd(1);dfs_mark(xx);dfs_mark(yy);
check();
return 0;
}