c++最近公共祖先LCA(倍增算法和tarjan)

1.倍增

找兩個點的LCA,先讓它們深度相同,然後倍增向上跳躍,跳到使他們的值不相同的最淺層的點,那麼此點的上方即是LCA。

#include<iostream>
#include<iomanip>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cmath>
#define in(x) scanf("%d",&x);
using namespace std;
int n,m,rt,d[500007],fa[500007][22];
int heade[1000007],nexte[1000007],cnt=0,to[1000007];
void build_tree(int x,int father)
{
    int k=log(d[x])/log(2);
    for(int j=1;j<=19;++j)
    fa[x][j]=fa[fa[x][j-1]][j-1];
    for(int i=heade[x];i;i=nexte[i])
    {
        int u=to[i];
        if(u!=father)
        {
            d[u]=d[x]+1;
            fa[u][0]=x;
            build_tree(u,x);
        }
    }
}
inline int read()
{
    int x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') w=-1;ch=getchar();
    }
    while(ch>='0'&&ch<='9') 
    {
        x=(x<<3)+(x<<1)+ch-'0';ch=getchar();
    }
    return x*w;
}
int find_lca(int l,int r)
{
    if(d[l]>d[r]) swap(l,r);//l=4,r=2
    int k=log(d[r]-d[l]+1)/log(2);
    for(int i=0;i<=k+1;++i)
    {
        if((1<<i)&(d[r]-d[l])) r=fa[r][i];
        if(d[r]==d[l]) break;
    }
    if(l==r) return l;
    k=log(d[l])/log(2);
    for(int i=k+1;i>=0;--i)
    {
        if(fa[l][i]==fa[r][i]) continue;
        l=fa[l][i];r=fa[r][i];
    }
    return fa[l][0];
}
int main()
{
    n=read();m=read();rt=read();
    for(int i=1;i<=n-1;++i)
    {
        int x,y;x=read();y=read();
        nexte[++cnt]=heade[x];heade[x]=cnt;to[cnt]=y;
        nexte[++cnt]=heade[y];heade[y]=cnt;to[cnt]=x;
    }
    d[rt]=1;fa[rt][0]=0;
    build_tree(rt,-1);
    for(int i=1;i<=m;++i)
    {
        int x,y;x=read();y=read();
        printf("%d\n",find_lca(x,y));
    }
    return 0;
}
2.tarjan算法

離線算法,將所求先儲存,然後結合並查集深搜點,如果所求的兩個點都vis[]==1那麼輸出他們的father。

不理解就畫一下圖模擬一下。

#include<iostream>
#include<iomanip>
#include<cstring>
#include<cstdio>
#include<cmath> 
#include<algorithm>
using namespace std;
int n,m,rt;
int to[1000009],head[1000009],nxt[1000009],cnt=0,cnt1=0,ans[500009],fa[500009];
int to1[1000009],head1[1000009],nxt1[1000009],mark[1000009];bool vis[500009];
inline int read()
{
    int x=0,w=1;char ch=0;
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') w=1;ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=(x<<3)+(x<<1)+ch-'0';ch=getchar();
    }
    return x*w;
} 
int find(int x)
{
    if(fa[x]==x) return x;
    else return fa[x]=find(fa[x]);
}
void onion(int a,int b)
{
    int x=find(a),y=find(b);
    if(x==y) return ;
    else fa[x]=y;
}
void find_lca(int x,int fa)
{    
    for(int i=head[x];i;i=nxt[i])
    if(to[i]!=fa)
    find_lca(to[i],x);
    vis[x]=1;
    for(int i=head1[x];i;i=nxt1[i])
    if(vis[to1[i]]==1)
    ans[mark[i]]=find(to1[i]);
    if(fa!=-1) onion(x,fa);
}
int main()
{
     n=read();m=read();rt=read();
     for(int i=1;i<=n;++i)
     fa[i]=i;
     for(int i=1;i<=n-1;++i)
     {
         int x,y;x=read();y=read();
         nxt[++cnt]=head[x];head[x]=cnt;to[cnt]=y;
         nxt[++cnt]=head[y];head[y]=cnt;to[cnt]=x;
     }
     for(int i=1;i<=m;++i)
     {
         int x,y;x=read();y=read();
         nxt1[++cnt1]=head1[x];head1[x]=cnt1;to1[cnt1]=y;mark[cnt1]=i;
         nxt1[++cnt1]=head1[y];head1[y]=cnt1;to1[cnt1]=x;mark[cnt1]=i;
     }
     find_lca(rt,-1);
     for(int i=1;i<=m;++i)
     printf("%d\n",ans[i]);
     return 0;
}



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章