【C++】樹上的最遠點對

【來源】

51Nod-1766
vjudge

【題目描述】

n個點被n-1條邊連接成了一顆樹,給出ab和cd兩個區間,表示點的標號請你求出兩個區間內各選一點之間的最大距離,即你需要求出max{dis(i,j) |a<=i<=b,c<=j<=d}
(PS 建議使用讀入優化)

【輸入格式】

第一行一個數字 n n<=100000。 第二行到第n行每行三個數字描述路的情況,x,y,z (1<=x,y<=n,1<=z<=10000)表示x和y之間有一條長度爲z的路。 第n+1行一個數字m,表示詢問次數 m<=100000。 接下來m行,每行四個數a,b,c,d。

【輸出格式】

共m行,表示每次詢問的最遠距離

【樣例輸入】

5
1 2 1
2 3 2
1 4 3
4 5 4
1
2 3 4 5

【樣例輸出】

10

【解析】

線段樹+LCA

給你兩個區間,問各從一個區間選擇一個點,兩個點之間的最長路是多少,這裏需要注意就是如果第一個區間是a和b最遠,第二個區間是c和d最遠,那麼答案一定是ab,cd,ac,ad,bc,bd,其中一個,於是我們只要用線段樹維護合併,外加LCA求兩個點的距離即可。

【代碼】

#pragma GCC optimize(3,"Ofast","inline")
#pragma G++ optimize(3,"Ofast","inline")

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>

#define RI                 register int
#define re(i,a,b)          for(RI i=a; i<=b; i++)
#define ms(i,a)            memset(a,i,sizeof(a))
#define MAX(a,b)           (((a)>(b)) ? (a):(b))
#define MIN(a,b)           (((a)<(b)) ? (a):(b))

using namespace std;

typedef long long LL;

namespace IO {
    template <typename T>
    inline void read(T &x){
        x=0; 
        char c=0; 
        T w=0;  
        while (!isdigit(c)) w|=c=='-',c=getchar();  
        while (isdigit(c)) x=(x<<3)+(x<<1)+(c^48),c=getchar();  
        if(w) x=-x;  
    }
    
    template <typename T>
    inline void write(T x) {
        if(x<0) putchar('-'),x=-x;
        if(x<10) putchar(x+'0');
            else write(x/10),putchar(x%10+'0');
    }

    template <typename T>
    inline void writesp(T x) {
        write(x);
        putchar(' ');
    }
    
    template <typename T>
    inline void writeln(T x) {
        write(x);
        putchar('\n');
    }
} 

using IO::read;
using IO::write;
using IO::writesp;
using IO::writeln;

const int N=1e5+5;
const int inf=1e9;

struct Edge {
    int to,nt,w;
} e[N<<1];

struct Ans {
    int len;
    int a[2];
} t[N<<2];

int n,m,cnt,sum;
int h[N],dep[N],f[N][20],tin[N],tout[N];

inline void add(int a,int b,int c) {
    e[cnt]=(Edge){b,h[a],c};
    h[a]=cnt++;
}

void dfs(int x,int fa,int d) {
    f[x][0]=fa;
    tin[x]=++sum;
    dep[x]=d;
    for(int i=h[x]; i!=-1; i=e[i].nt) {
        int v=e[i].to;
        if(v==fa) continue;
        dfs(v,x,d+e[i].w);
    }
    tout[x]=++sum;
}

inline int ancestor(int x,int y) {
    return tin[x]<=tin[y] && tout[y]<=tout[x];
}

inline int lca(int x,int y) {
    if(ancestor(x,y)) return x;
    if(ancestor(y,x)) return y;
    for(int i=16; i>=0; i--) 
        if(!ancestor(f[x][i],y)) x=f[x][i];
    return f[x][0];
}

inline int dist(int x,int y) {
    int k=lca(x,y);
    return dep[x]+dep[y]-(dep[k]<<1);
}

#define lch (o<<1)
#define rch (o<<1|1)
#define mid ((l+r)>>1)

void pushup(int o) {
    t[o]=t[lch];
    if(t[o].len<t[rch].len) t[o]=t[rch];
    for(int i=0; i<=1; i++) for(int j=0; j<=1; j++) {
        int tmp=dist(t[lch].a[i],t[rch].a[j]);
        if(tmp>t[o].len) {
            t[o].len=tmp;
            t[o].a[0]=t[lch].a[i];
            t[o].a[1]=t[rch].a[j];
        }
    }
}

void build(int o,int l,int r) {
    if(l==r) {
        t[o].len=0;
        t[o].a[0]=l;
        t[o].a[1]=r;
        return;
    }
    build(lch,l,mid);
    build(rch,mid+1,r);
    pushup(o);
}

Ans query(int o,int l,int r,int ll,int rr) {
    if(l==ll && r==rr) return t[o];
    if(rr<=mid) return query(lch,l,mid,ll,rr);
        else if(ll>mid) return query(rch,mid+1,r,ll,rr);
            else {
                Ans la=query(lch,l,mid,ll,mid);
                Ans ra=query(rch,mid+1,r,mid+1,rr);
                Ans ta;
                if(la.len>ra.len) ta=la;
                    else ta=ra;
                for(int i=0; i<=1; i++) for(int j=0; j<=1; j++) {
                    int tmp=dist(la.a[i],ra.a[j]);
                    if(tmp>ta.len) {
                        ta.len=tmp;
                        ta.a[0]=la.a[i];
                        ta.a[1]=ra.a[j];
                    }
                }
                return ta;
            }
}

int main() {
    read(n);
    memset(h,-1,sizeof(h));
    for(int i=1; i<n; i++) {
        int x,y,z;
        read(x);
        read(y);
        read(z);
        add(x,y,z);
        add(y,x,z);
    }
    dfs(1,1,0);
    for(int j=1; j<=16; j++) for(int i=1; i<=n; i++)
        f[i][j]=f[f[i][j-1]][j-1];
    build(1,1,n);
    read(m);
    while(m--) {
        int a,b,c,d;
        read(a);
        read(b);
        read(c);
        read(d);
        Ans la,ra;
        int ans=0;
        la=query(1,1,n,a,b);
        ra=query(1,1,n,c,d);
        for(int i=0; i<=1; i++) for(int j=0; j<=1; j++) {
            int tmp=dist(la.a[i],ra.a[j]);
            ans=MAX(ans,tmp);
        }
        writeln(ans);
    }
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章