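// Problem (題目鏈接): https://vjudge.net/problem/SPOJ-COT2
// Tutorial (教學博客): https://www.cnblogs.com/zwfymqz/p/9223425.html
// Task (題目大意): for each query (u, v), output the number of distinct node weights
//   on the tree path between u and v. (查詢鏈上點權不同數的個數。)
// Note (ps): if the weights were on edges instead of nodes (each edge's weight stored
//   on its deeper endpoint), the Euler-tour interval for (u, v) with st[u] <= st[v]
//   is [st[u]+1, st[v]], and no LCA special case is needed.
//   (如果是邊權的話,那麼u,v對應的區間是[st[u]+1,st[v]],無需lca。)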
#include <bits/stdc++.h>
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
#define pb push_back
#define all(x) (x).begin(),(x).end()
using namespace std;
// ---- Tree storage (all arrays are 1-indexed by node) -----------------------
const int N = 40001;      // capacity for node-indexed arrays (nodes are 1..n)
const int M = 100001;     // capacity for query-indexed arrays
int st[N];                // st[x]: position of x's entry appearance in euler[]
int cnt;                  // number of positions written to euler[] so far
int euler[N * 2];         // Euler tour: each node is written on entry and on exit
int deep[N];              // deep[x]: depth of x (deep[0] == 0 acts as the sentinel)
int ed[N];                // ed[x]: position of x's exit appearance in euler[]
int fa[N][20];            // fa[x][j]: 2^j-th ancestor of x, 0 once past the root
std::vector<int> nxt[N];  // adjacency lists of the undirected tree
// Depth-first traversal from u (whose parent is f) that fills, for every node x
// reached without passing through f:
//   deep[x]     = deep[parent] + 1 (so u itself gets deep[f] + 1),
//   fa[x][0]    = parent of x,
//   st[x]/ed[x] = positions of x's entry/exit writes in euler[1..cnt].
// Every node is written to euler[] twice (once on entry, once on exit), and
// neighbours are visited in nxt[] order, skipping any neighbour equal to the
// parent. This produces the same tour as the plain recursive formulation.
// An explicit stack is used instead of recursion because the recursion depth
// equals the tree height (up to ~N on a path-shaped tree), which can overflow
// small default call stacks.
void dfs(int u, int f) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index into nxt[node] of the next neighbour to examine
    };
    std::vector<Frame> stack;

    // Entry work for node x with parent p (the recursive version's prologue).
    auto enter = [&stack](int x, int p) {
        deep[x] = deep[p] + 1;
        fa[x][0] = p;
        euler[++cnt] = x;
        st[x] = cnt;
        stack.push_back(Frame{x, p, 0});
    };

    enter(u, f);
    while (!stack.empty()) {
        // Copy what we need out of the top frame: enter() may reallocate `stack`.
        const int x = stack.back().node;
        const int p = stack.back().parent;
        const std::size_t i = stack.back().next;
        if (i < nxt[x].size()) {
            stack.back().next = i + 1;
            const int y = nxt[x][i];
            if (y != p) enter(y, x);
        } else {
            // Exit work for x (the recursive version's epilogue).
            euler[++cnt] = x;
            ed[x] = cnt;
            stack.pop_back();
        }
    }
}
// Build the binary-lifting table for nodes 1..n from the parents in fa[*][0]
// (set by dfs): the 2^lvl-th ancestor is the 2^(lvl-1)-th ancestor of the
// 2^(lvl-1)-th ancestor. fa[0][*] is never written and stays 0, so jumps that
// go past the root land on the sentinel 0. Levels are filled in increasing
// order because each level reads the previous level of other nodes.
void init(int n) {
    for (int lvl = 1; lvl < 20; ++lvl) {
        for (int x = 1; x <= n; ++x) {
            const int mid = fa[x][lvl - 1];
            fa[x][lvl] = fa[mid][lvl - 1];
        }
    }
}
// Lowest common ancestor of u and v using deep[] and the fa[][] lifting table
// (both must already be filled by dfs and init).
int lca(int u, int v) {
    if (deep[u] < deep[v]) swap(u, v);  // u is now the deeper (or equally deep) node
    // Lift u by exactly deep[u] - deep[v] levels, one set bit of the gap at a time.
    for (int j = 19, gap = deep[u] - deep[v]; j >= 0; --j) {
        if ((gap >> j) & 1) u = fa[u][j];
    }
    if (u == v) return u;
    // Raise both while their 2^j-th ancestors still differ; they stop just below the LCA.
    for (int j = 19; j >= 0; --j) {
        if (fa[u][j] == fa[v][j]) continue;
        u = fa[u][j];
        v = fa[v][j];
    }
    return fa[u][0];
}
//-------------------------------------------------------
int ans[M];      // ans[id]: answer to the id-th query, in input order
int block;       // Mo's block length over Euler-tour positions
int Ans;         // number of distinct compressed weights in the current window
int num[M];      // num[w]: how many currently counted nodes have compressed weight w
int s[N];        // s[x]: weight of node x (compressed to 1..#distinct in main)
int bid[2 * N];  // bid[pos]: Mo block index of Euler-tour position pos
int bnum;        // number of Mo blocks
bool vis[M];     // vis[x]: node x appears an odd number of times in the window
// One query mapped to a contiguous Euler-tour interval [l, r].
//   lca: extra node that must be toggled in while answering (0 if none)
//   id:  index of the query in the input, where the answer is stored
struct node {
    int l, r, lca, id;
    // Mo's ordering: primarily by the block of l. Within a block, r ascends
    // for odd blocks and descends for even ones, so the right pointer sweeps
    // back and forth instead of resetting (a constant-factor optimization).
    bool operator<(const node &cmp) const {
        const int mine = bid[l];
        const int other = bid[cmp.l];
        if (mine != other) return mine < other;
        if (mine & 1) return r < cmp.r;
        return cmp.r < r;
    }
} q[M];
// Count one more node with compressed weight x; a weight that was absent
// before increases the distinct counter Ans.
void add(int x) {
    num[x] += 1;
    if (num[x] == 1) ++Ans;
}
// Count one fewer node with compressed weight x; when its count drops to
// zero the weight is gone and the distinct counter Ans decreases.
void del(int x) {
    num[x] -= 1;
    if (num[x] == 0) --Ans;
}
// Toggle node p in the current window. A node that appears an odd number of
// times in the Euler-tour window lies on the path and is counted; one that
// appears an even number of times is not. vis[p] tracks that parity.
void Add(int p) {
    const bool counted = vis[p];
    vis[p] = !counted;
    if (counted) {
        del(s[p]);
    } else {
        add(s[p]);
    }
}
//-------------------------------------------------------
int n, m;            // node count and query count
int u, v;            // scratch endpoints, reused for every edge and every query
std::vector<int> k;  // sorted distinct weights used for coordinate compression
int main() {
ios::sync_with_stdio(0);
cin>>n>>m;
rep(i, 1, n) cin>>s[i],k.pb(s[i]);
sort(all(k)); //點權離散化
k.erase(unique(all(k)),k.end());
rep(i, 1, n) s[i] = lower_bound(all(k),s[i])-k.begin()+1;
rep(i, 1, n-1) {
cin>>u>>v;
nxt[u].pb(v);
nxt[v].pb(u);
}
dfs(1,0);
init(n);
block = sqrt(2*n);
bnum = ceil(double(2*n)/block);
rep(i, 1, bnum) rep(j, block*(i-1)+1, min(2*n,i*block)) bid[j] = i; //常數優化
rep(i, 1, m) {
cin>>u>>v;
if(st[u]>st[v]) swap(u,v);
int _lca = lca(u,v);
if(_lca==u) q[i] = (node){st[u],st[v],0,i};
else q[i] = (node){ed[u],st[v],_lca,i};
}
sort(q+1,q+m+1);
int l = 1,r = 0;
rep(i, 1, m){
while(r < q[i].r) Add(euler[++r]);
while(r > q[i].r) Add(euler[r--]);
while(l < q[i].l) Add(euler[l++]);
while(l > q[i].l) Add(euler[--l]);
if(q[i].lca) Add(q[i].lca); //特判lca
ans[q[i].id] = Ans;
if(q[i].lca) Add(q[i].lca);
}
rep(i, 1, m) printf("%d\n",ans[i]);
return 0;
}