CSP2019Day2T3 洛谷P5666:樹的重心 (樹上倍增)

題目傳送門:https://www.luogu.org/problem/P5666


很早以前就覺得,凡是考樹的重心相關的題,到最後都變成一道模擬題。

樹的重心有許多優秀的性質,比如:
結論一:記f(node)表示以node爲根的最大子樹的大小。從無根樹上的任意一個點x出發,向相鄰的點走一步。假設某個點y的f值比x的小,那麼x向y走就相當於向樹的重心移動了一步。假設不存在這樣的y,那麼x就是重心。假設存在f(y)=f(x),那麼它們就是樹的兩個重心。
結論二:設全樹爲有根樹。記g(node)表示node爲根的子樹的重心。假設son是node的子樹中size最大的兒子(重兒子),那麼g(node)一定是g(son)的祖先,因此讓g(son)向上跳即可。這個很顯然。
由結論二可得結論三:g(node)一定在node所在的重鏈上。

基於上述三條結論,很容易得出一個從任意一點x走到當前全樹的重心的快速算法:
x先向上跳,直到跳到一個y,使得y上方的size大於y的最大子樹的size,而father(y)並不。要注意的是y可能並不存在,而此時x本身就是我們要找的。
接下來就是沿着father(y)或x的重鏈向下跳。但這樣很麻煩,簡單的方法是拿出father(y)/x的重兒子的重心向上跳。

那麼對於這道題而言,整棵樹砍掉node的子樹之後的重心,就是從node的父親重新向上跳,然後再向下跳(或者計算新的重兒子,從它的重心向上跳)。實際處理起來細節比較多,尤其是可能出現y和father(y)本身就是兩個重心之類的極端情況。

然後xjb寫,xjb調,就過了。


CODE:

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
using namespace std;

const int maxn=300100;
const int maxl=22;
typedef long long LL;

struct edge
{
    int obj;
    edge *Next;
} e[maxn<<1];
edge *head[maxn];
int cur;

int fa[maxn][maxl];
int dep[maxn];

int st[maxn];
int ed[maxn];
int Time;

int Size[maxn];
int son1[maxn];
int son2[maxn];
int w[maxn];

LL ans;
int n,t;

void Add(int x,int y)
{
    cur++;
    e[cur].obj=y;
    e[cur].Next=head[x];
    head[x]=e+cur;
}

void Update(int son,int node)
{
    if (Size[son]>=Size[ son1[node] ])
    {
        son2[node]=son1[node];
        son1[node]=son;
    }
    else
        if (Size[son]>Size[ son2[node] ])
            son2[node]=son;
}

int Jump(int node,int sum,int limit)
{
    for (int j=maxl-1; j>=0; j--)
    {
        int x=fa[node][j];
        if (dep[x]<limit) continue;
        if (sum-Size[x]>=Size[ son1[x] ]) node=x;
    }
    return node;
}

int Get(int node,int sum)
{
    return max(sum-Size[node],Size[ son1[node] ]);
}

void Dfs(int node)
{
    st[node]=++Time;
    dep[node]=dep[ fa[node][0] ]+1;
    for (int j=1; j<maxl; j++) fa[node][j]=fa[ fa[node][j-1] ][j-1];
    Size[node]=1;

    for (edge *p=head[node]; p; p=p->Next)
    {
        int son=p->obj;
        if (son!=fa[node][0])
        {
            fa[son][0]=node;
            Dfs(son);
            Size[node]+=Size[son];
            Update(son,node);
        }
    }

    ed[node]=Time;
    if (!son1[node]) w[node]=node,ans+=node;
    else
    {
        int x=w[ son1[node] ];
        x=Jump(x,Size[node],dep[node]);
        if ( x!=node && Get(fa[x][0],Size[node])<Get(x,Size[node]) ) x=fa[x][0];
        w[node]=x;
        if (node!=1) ans+=x;
        if ( node!=1 && x!=node && Get(fa[x][0],Size[node])==Get(x,Size[node]) ) ans+=fa[x][0];
    }
}

int Calc(int x,int dec)
{
    int y=son1[x];
    if ( st[y]<=st[dec] && st[dec]<=ed[y] ) return max(Size[y]-Size[dec],Size[ son2[x] ]);
    else return Size[y];
}

int Jump(int node,int sum,int limit,int dec)
{
    for (int j=maxl-1; j>=0; j--)
    {
        int x=fa[node][j];
        if (dep[x]<limit) continue;
        if (sum-Size[x]>= Calc(x,dec) ) node=x;
    }
    return node;
}

struct data
{
    int u,v;
} a[4];

bool Comp(data x,data y)
{
    return x.v<y.v;
}

void Work(int node)
{
    LL tp=ans;

    int x=fa[node][0];
    x=Jump(x,n,1,node);
    //if (node==4) printf("%d\n",x);
    int y=fa[x][0];

    for (int i=0; i<4; i++) a[i].v=100000000;
    a[0].u=x;
    a[0].v=max(n-Size[x],Calc(x,node));
    if (y)
    {
        a[1].u=y;
        a[1].v=max(n-Size[y],Calc(y,node));
    }

    if ( x!=1 && n-Size[x]> Calc(x,node) ) x=y;
    int p=son1[x];
    if ( st[p]<=st[node] && st[node]<=ed[p] )
        if (!son2[x]) p=0;
        else p=son2[x];
    
    if (p)
    {
        p=w[p];
        p=Jump(p,n-Size[node],dep[x]+1);
        a[2].u=p;
        a[2].v=Get(p,n-Size[node]);

        p=fa[p][0];
        if (p!=x)
        {
            a[3].u=p;
            a[3].v=Get(p,n-Size[node]);
        }
    }

    sort(a,a+4,Comp);

    //if (node==4)
        //for (int i=0; i<4; i++) printf("%d %d\n",a[i].u,a[i].v);

    ans+=a[0].u;
    if (a[0].v==a[1].v) ans+=a[1].u;

    //printf("%I64d\n",ans-tp);
}

int main()
{
    scanf("%d",&t);
    while (t--)
    {
        scanf("%d",&n);

        cur=-1;
        for (int i=1; i<=n; i++) head[i]=NULL,son1[i]=son2[i]=0;
        for (int i=1; i<n; i++)
        {
            int x,y;
            scanf("%d%d",&x,&y);
            Add(x,y);
            Add(y,x);
        }

        Time=0;
        ans=0;
        Dfs(1);

        //printf("%I64d\n",ans);

        for (int i=2; i<=n; i++) Work(i);

        //for (int i=1; i<=n; i++) printf("%d %d ",son1[i],son2[i]);
        //printf("\n");

        printf("%lld\n",ans);
    }

    return 0;
}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章