[POJ1987]Distance Statistics(點分治)

【題目鏈接】http://poj.org/problem?id=1987
【題目大意】給定一棵樹,每條邊有權值,求距離<=k的點對數
【解題思路】樹上點分治基礎題,更新答案時用容斥原理把子樹信息一起處理
【呆馬】

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<iostream>
const int N=40001;
using namespace std;
struct st{int to,next,v;} e[N<<1];
int n,x,y,z,k,cnt,m,num,root,t,p,i,ans,fi[N],f[N],siz[N],b[N],d[N],s[N];
bool vis[N];
char ch;
void add(int x,int y,int z)
{
    e[++cnt].to=y; e[cnt].next=fi[x]; e[cnt].v=z; fi[x]=cnt;
    e[++cnt].to=x; e[cnt].next=fi[y]; e[cnt].v=z; fi[y]=cnt;
}

void getroot(int x,int fa)
{
    f[x]=0;
    siz[x]=1;
    for (int i=fi[x];i;i=e[i].next)
        if (e[i].to!=fa && !vis[e[i].to])
        {
            int y=e[i].to;
            getroot(y,x);
            f[x]=max(f[x],siz[y]);
            siz[x]+=siz[y]; 
        }
    f[x]=max(f[x],num-siz[x]);
    if (f[x]<f[root]) root=x;
}

void go(int x,int fa)
{
    ans++;
    for (int i=fi[x];i;i=e[i].next)
        if (e[i].to!=fa && !vis[e[i].to])
        {
            d[++t]=s[e[i].to]=s[x]+e[i].v;
            if (d[t]>k)
            {
                t--;
                return;
            }
            go(e[i].to,x);
        }
}

void calc(int x)
{
    m=0;
    for (int i=fi[x];i;i=e[i].next)
        if (!vis[e[i].to] && e[i].v<=k)
        {
            d[t=1]=s[e[i].to]=e[i].v;
            go(e[i].to,x);
            sort(d+1,d+t+1);
            p=t;
            for (int i=1;i<=t;i++)
            {
                b[++m]=d[i];
                for (;p>=0 && d[i]+d[p]>k;p--);
                if (p>=i) ans-=p-i+1;
            }
        }
    sort(b+1,b+m+1);
    p=m;
    for (int i=1;i<=m;i++)
    {
        for (;p>=0 && b[i]+b[p]>k;p--);
        if (b[i]>k || !p) break;
        if (p>=i) ans+=p-i+1;
    }
}

void part(int x)
{
    vis[x]=1;
    calc(x);
    for (int i=fi[x];i;i=e[i].next)
        if (!vis[e[i].to])
        {
            root=0;
            num=siz[e[i].to];
            getroot(e[i].to,0);
            part(root);
        }
}

int main()
{
        scanf("%d%d\n",&n,&x);
        for (i=1;i<n;i++)
        {
            scanf("%d%d%d %c\n",&x,&y,&z,&ch);
            add(x,y,z);
        }
        scanf("%d",&k);
        f[0]=1e9;
        num=n;
        getroot(1,0);
        part(root);
        printf("%d",ans);
}
發佈了33 篇原創文章 · 獲贊 2 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章