【uoj150】 NOIP2015—運輸計劃

http://uoj.ac/problem/150 (題目鏈接)

題意:給出一棵樹以及m個詢問,可以將樹上一條邊的權值修改爲0,求經過這樣的修改之後最長的邊最短是多少。

Solution
  老早就聽說過這道題了,好像使用樹鏈剖分。
  先樹鏈剖分求出每個詢問的路程,最長的最短,可以用二分做。二分最長的邊的大小,也就是最後的答案,問題來了,怎麼判斷這個答案是否可行呢?
  我們記錄下所有超出當前答案的詢問的個數p,用d記錄下符合條件的邊比當前二分的答案最大大多少,並給所有詢問的兩端點u,v的sum[]加上1,給他們的最近公共祖先f的sum[]減去2。這樣做有什麼用呢?這樣就可以統計每條邊經過了多少次了。
  我們通過dfs,每經過一條邊i,就把cnts[i]加上當前節點的sum值,這就代表有多少個點會經過這條邊,回溯的時候更新sum即可。
  最後的時候如果存在一條邊被經過的次數正好等於當前詢問數p,並且這條邊的長度大於等於d,那麼就是合法的,否則不合法。
  其實這樣的話根本就不用寫樹鏈剖分,dfs一遍就搞完了。。。然而不知道爲什麼最後extra test被卡的爆空間了。。好像是爆棧?而且讀入優化也gi了,真的鬼畜。。。

代碼:

// uoj150
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<cmath>
#define LL long long
#define inf 2147483640
#define Pi acos(-1.0)
#define free(a) freopen(a".in","r",stdin),freopen(a".out","w",stdout);
using namespace std;
int getint() {
    int f=1,x=0;char ch=getchar();
    while (ch<='0' || ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0' && ch<='9') {x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

const int maxn=300010;
struct edge {int w,to,next;}e[maxn<<1];
struct tree {int l,r,s;}tr[maxn<<1];
struct ask {int u,v,dis;}q[maxn];
int bin[20],fa[maxn][20],head[maxn],deep[maxn],size[maxn],chain[maxn],pos[maxn],sum[maxn],cnts[maxn<<1];
int n,m,cnt,num;

void link(int u,int v,int w) {
    e[++cnt].to=v;e[cnt].next=head[u];head[u]=cnt;e[cnt].w=w;
    e[++cnt].to=u;e[cnt].next=head[v];head[v]=cnt;e[cnt].w=w;
}
void build(int k,int s,int t) {
    tr[k].l=s;tr[k].r=t;tr[k].s=0;
    if (s==t) return;
    int mid=(s+t)>>1;
    build(k<<1,s,mid);
    build(k<<1|1,mid+1,t);
}
void update(int k,int s,int w) {
    int l=tr[k].l,r=tr[k].r;
    if (s==l && s==r) {tr[k].s+=w;return;}
    int mid=(l+r)>>1;
    if (s<=mid) update(k<<1,s,w);
    else update(k<<1|1,s,w);
    tr[k].s=tr[k<<1].s+tr[k<<1|1].s;
}
int query(int k,int s,int t) {
    int l=tr[k].l,r=tr[k].r;
    if (s==l && t==r) return tr[k].s;
    int mid=(l+r)>>1;
    if (t<=mid) return query(k<<1,s,t);
    else if (s>mid) return query(k<<1|1,s,t);
    else return query(k<<1,s,mid)+query(k<<1|1,mid+1,t);
}
void dfs1(int x) {
    for (int i=1;i<20;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
    size[x]=1;
    for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x][0]) {
            deep[e[i].to]=deep[x]+1;
            fa[e[i].to][0]=x;
            dfs1(e[i].to);
            size[x]+=size[e[i].to];
        }
}
void dfs2(int x,int top) {
    chain[x]=top;
    pos[x]=num++;
    int k=0;
    for (int i=head[x];i;i=e[i].next) {
        if (e[i].to!=fa[x][0]) {
            if (size[e[i].to]>size[k]) k=e[i].to;
        }
        else update(1,pos[x],e[i].w);
    }
    if (!k) return;
    dfs2(k,top);
    for (int i=head[x];i;i=e[i].next)
        if (e[i].to!=fa[x][0] && e[i].to!=k) dfs2(e[i].to,e[i].to);
}
int lca(int x,int y) {
    if (deep[x]<deep[y]) swap(x,y);
    int t=deep[x]-deep[y];
    for (int i=0;bin[i]<=t;i++) if (t&bin[i]) x=fa[x][i];
    for (int i=19;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
    return x==y?x:fa[x][0];
}
int solve(int x,int f) {
    int sum=0;
    while (chain[x]!=chain[f]) {
        sum+=query(1,pos[chain[x]],pos[x]);
        x=fa[chain[x]][0];
    }
    if (pos[f]+1<=pos[x]) sum+=query(1,pos[f]+1,pos[x]);
    return sum;
}
void dfs(int x) {
    for (int i=head[x];i;i=e[i].next) if (e[i].to!=fa[x][0]) {
            dfs(e[i].to);
            sum[x]+=sum[e[i].to];
            cnts[i]=sum[e[i].to];
        }
}
bool check(int mid) {
    int d=0,p=0;
    memset(sum,0,sizeof(sum));
    for (int u=q[1].u,v=q[1].v,i=1;i<=n;i++,u=q[i].u,v=q[i].v) 
        if (q[i].dis>mid) {
            sum[u]++;
            sum[v]++;
            sum[lca(u,v)]-=2;
            p++;
            d=max(d,q[i].dis-mid);
        }
    dfs(1);
    for (int i=1;i<=cnt;i++)
        if (p==cnts[i] && e[i].w>=d) return 1;
    return 0;
}
int main() {
    bin[0]=1;for (int i=1;i<20;i++) bin[i]=bin[i-1]<<1;
    scanf("%d%d",&n,&m);//n=getint();m=getint();
    for (int u,v,w,i=1;i<n;i++) {
        //link(getint(),getint(),getint());
        scanf("%d%d%d",&u,&v,&w);
        link(u,v,w);
    }
    build(1,1,n);
    dfs1(1);
    dfs2(1,1);
    int L=0,R=-inf,ans=0;
    for (int u,v,i=1;i<=m;i++) {
        q[i].u=u=getint(),q[i].v=v=getint();
        int f=lca(u,v);
        q[i].dis=solve(u,f)+solve(v,f);
        R=max(q[i].dis,R);
    }
    while (L<=R) {
        int mid=(L+R)>>1;
        if (check(mid)) {ans=mid;R=mid-1;}
        else L=mid+1;
    }
    printf("%d",ans);
    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章