【模板】樹上莫隊

模板題

#include<bits/stdc++.h>
using namespace std;
#define N 100005
vector<int>g[N];
int st[N],nd[N],a[N];
int pos[N],dfn[N],cnt,sum[N];
int fa[N][20],dep[N],ans[N],n,m,an;
bool f[N];
struct mzls
{
	int l,r,id;
	bool p;
}a1[N];
inline bool cmp(mzls x,mzls y)
{
	if(pos[x.l]!=pos[y.l])
		return pos[x.l]<pos[y.l];
	return x.r<y.r;
}
map<int,int>mp;
inline void dfs(int x)
{
	dfn[++cnt]=x;
	st[x]=cnt;
	dep[x]=dep[fa[x][0]]+1;
	for(int i=1;i<20;i++)
		fa[x][i]=fa[fa[x][i-1]][i-1];
	int l1=g[x].size();
	for(int i=0;i<l1;i++)
	{
		if(g[x][i]==fa[x][0])
			continue;
		fa[g[x][i]][0]=x;
		dfs(g[x][i]);
	}
	dfn[++cnt]=x;
	nd[x]=cnt;
}
inline int LCA(int x,int y)
{
	if(dep[x]<dep[y])
		swap(x,y);
	for(int i=19;i>=0;i--)
		if(dep[fa[x][i]]>=dep[y])
			x=fa[x][i];
	if(x==y)
		return x;
	for(int i=19;i>=0;i--)
		if(fa[x][i]!=fa[y][i])
		{
			x=fa[x][i];
			y=fa[y][i];
		}
	return fa[x][0];
}
inline void modify(int x)
{
	if(f[x])
	{
		sum[a[x]]--;
		if(sum[a[x]]==0)
			an--;
		f[x]=0;
	}
	else
	{
		sum[a[x]]++;
		if(sum[a[x]]==1)
			an++;
		f[x]=1;
	}
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&a[i]);
		if(mp[a[i]]==0)
			mp[a[i]]=++cnt;
		a[i]=mp[a[i]];
	}
	for(int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		g[x].push_back(y);
		g[y].push_back(x);
	}
	dfs(1);
	int d=sqrt(cnt);
	for(int i=1;i<=cnt;i++)
		pos[i]=(i-1)/d;
	for(int i=1;i<=m;i++)
	{
		int x,y;
		scanf("%d%d",&a1[i].l,&a1[i].r);
		if(st[a1[i].l]>st[a1[i].r])
			swap(a1[i].l,a1[i].r);
		if(nd[a1[i].l]>st[a1[i].r])
		{
			x=nd[a1[i].r];
			y=nd[a1[i].l];
			a1[i].p=0;
		}
		else
		{
			x=nd[a1[i].l];
			y=st[a1[i].r];
			a1[i].p=1;
		}
		a1[i].l=x;
		a1[i].r=y;
		a1[i].id=i;
	}
	sort(a1+1,a1+m+1,cmp);
	int l1=1,r1=0;
	for(int i=1;i<=m;i++)
	{
		while(l1>a1[i].l)
			modify(dfn[--l1]);
		while(l1<a1[i].l)
			modify(dfn[l1++]);
		while(r1<a1[i].r)
			modify(dfn[++r1]);
		while(r1>a1[i].r)
			modify(dfn[r1--]);
		int lc=LCA(dfn[l1],dfn[r1]);
		if(a1[i].p)
			modify(lc);
		ans[a1[i].id]=an;
		if(a1[i].p)
			modify(lc);
	}
	for(int i=1;i<=m;i++)
		printf("%d\n",ans[i]);
	return 0;
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章