LCA問題,
先輸入n,和q;表示節點的個數和操作的個數
然後一行n個數,表示這n個節點的權值
然後n-1行,每行兩個數,表示節點u與節點v相連
最後q行,表示操作 每行三個數,k,u,v;
當k=0時,將u點的權值改爲v
當k>0時,輸出從u->v路上第K大的權值
首先求出節點u和節點v的最近公共祖先lc,然後將節點u到lc的點的權值記錄到p數組中,再將v到lc的權值記錄到p數組中,(祖先節點lc被記錄了兩次,需要刪除一次),如果數組p的長度小於k則輸出"invalid request!",否則就將p數組從大到小排序,輸出p[k-1]
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<string.h>
#include<map>
#include<math.h>
#include<queue>
#include<vector>
using namespace std;
#define nn 80100
#define inf 0x7fffffff
#define ll long long
vector<int>e[nn];
int vis[nn],first[nn],node[nn<<1],dep[nn<<1],fa[nn];
int p[nn<<1],dp[nn<<1][20],val[nn];
int _pow[25];
//node記錄DFS序
//first記錄每個點第一次在DFS序中出現的位置
//dep記錄DFS序中每個點的深度
//fa記錄每個幾點的父親節點
//val記錄每個位置的權重
//dp[i][j]表示在i到i+(1<<j)這個區間內dep值最小的下標
void dfs(int &index,int u,int d,int par)
{
index++;
vis[u]=1;
first[u]=index;
node[index]=u;
dep[index]=d;
fa[u]=par;
for(int i=0;i<e[u].size();i++)
{
if(!vis[e[u][i]])
{
dfs(index,e[u][i],d+1,u);
index++;
node[index]=u;
dep[index]=d;
}
}
}
void rmq_init(int n)
{
//int m=(int)(log((double)(n*1.0))/long(2.0));
int m=0;
while(_pow[m+1]<=n) m++;
//printf("m==%d\n",m);
for(int i=1;i<=n;i++)
dp[i][0]=i;
for(int j=1;j<=m;j++)
{
for(int i=1;(i+_pow[j]-1)<=n;i++)
{
int a=dp[i][j-1];
int b=dp[i+_pow[j-1]][j-1];
dp[i][j] = dep[a]<dep[b] ? a:b;
}
}
}
int rmq(int x,int y)
{
//int m=(int)(log((double)(y-x+1))/long(2.0));
int m=0;
while(_pow[m+1]<=(y-x+1)) m++;
//printf("m===%d\n",m);
int a=dp[x][m];
int b=dp[y-_pow[m]+1][m];
return dep[a]<dep[b]?a:b;
}
int lca(int u,int v)
{
int x=first[u];
int y=first[v];
if(x>y) swap(x,y);
int index=rmq(x,y);
//printf("index==%d %d %d\n",x,y,index);
return node[index];
}
void path(int &index,int s,int t)
{
while(s!=t)
{
p[index++]=val[s];
s=fa[s];
}
p[index++]=val[t];
}
bool cmp(int x,int y)
{
return x>y;
}
void solve(int k,int u,int v)
{
int lc=lca(u,v);
//printf("%d %d %d\n",u,v,lc);
int t=0;
path(t,u,lc);
path(t,v,lc);
t--;//公共祖先被記錄兩次
if(k>t)
{
printf("invalid request!\n");
return;
}
sort(p,p+t,cmp);
printf("%d\n",p[k-1]);
}
int main()
{
for(int i=0;i<20;i++) _pow[i]=1<<i;
int n,q;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
int tot=0;
memset(vis,0,sizeof(vis));
dfs(tot,1,1,-1);
//分別表示dfs的長度,跑到的節點,深度,該點的父親節點(-1表示沒有父親節點)
// for(int i=1;i<=tot;i++) printf("%d ",node[i]);
// cout<<endl;
// for(int i=1;i<=tot;i++) printf("%d ",dep[i]);
// cout<<endl;
// for(int i=1;i<=n;i++) printf("%d ",first[i]);
// cout<<endl;
// for(int i=1;i<=n;i++) printf("%d ",fa[i]);
// cout<<endl;
// printf("tot==%d\n",tot);
rmq_init(tot);
while(q--)
{
int op,u,v;
scanf("%d%d%d",&op,&u,&v);
if(op==0) val[u]=v;
else solve(op,u,v);
}
return 0;
}