題目大意:n個節點的樹,m次操作,每次將白點變黑,將黑點變白,或詢問最遠黑點對的距離。
若無修改,可直接樹形dp或點分求即可,加上修改的話就要用到點分樹了(orz括號序列的做法)。。。考慮沒修改時點分治的做法爲對每個root,最遠點對即爲不同子樹上最遠和次遠黑點的距離之和,所以對於點分樹的每個結點 X 用個堆S維護X的每棵子樹最遠黑點距離,以及一個堆T維護 X 所在子樹的每個黑點與上一層節點 fa[x] 的距離。然後在加個堆ans維護每個節點的答案即可。。
這裏用的multiset來代替的堆,寫法較容易,但時間也慢得飛起。。
#include<iostream>
#include<stdio.h>
#include<set>
#include<vector>
#include<math.h>
using namespace std;
vector<int>g[100010];
int son[100010],siz[100010],ma[100010],vis[100010],root,A[100010];
int dis[100010][20],dep[100010],fa[100010][20];
void dfsroot(int u,int f,int sum){
int i,v;
son[u]=0;
siz[u]=1;
for(i=0;i<g[u].size();i++){
v=g[u][i];
if(v!=f&&vis[v]==0){
dfsroot(v,u,sum);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
ma[u]=max(siz[son[u]],sum-siz[u]);
if(root==0||ma[u]<ma[root]) root=u;
}
void dfsdis(int u,int f,int k,int d){ //預處理出每個節點到每一層根節點的距離,下面就不需要用lca來求距離辣
int i,v;
dep[u]++;
fa[u][dep[u]]=k;
dis[u][dep[u]]=d;
for(i=0;i<g[u].size();i++){
v=g[u][i];
if(v!=f&&vis[v]==0) dfsdis(v,u,k,d+1);
}
}
void dfs(int u){
int i,v;
vis[u]=1;
dfsdis(u,0,u,0);
for(i=0;i<g[u].size();i++){
v=g[u][i];
if(vis[v]==0){
root=0;
dfsroot(v,0,siz[v]);
dfs(root);
}
}
}
multiset<int>s[100010],t[100010],ans;
multiset<int>::iterator it;
int max1(int u){
if(t[u].size()==0) return -1;
it=t[u].end();
it--;
return *it;
}
int max2(int u){
if(s[u].size()<2) return -1;
it=s[u].end();
it--;
int a=*it;
it--;
return a+*it;
}
void update(int u){
int i,f,ff;
for(i=dep[u];i;i--){
f=fa[u][i];
ff=fa[u][i-1];
if(i==dep[u]){
if(max2(f)!=-1) ans.erase(ans.find(max2(f)));
if(A[u]) s[f].erase(0);
else s[f].insert(0);
if(max2(f)!=-1) ans.insert(max2(f));
}
if(ff){
if(max2(ff)!=-1) ans.erase(ans.find(max2(ff)));
if(max1(f)!=-1) s[ff].erase(s[ff].find(max1(f)));
if(A[u]) t[f].erase(t[f].find(dis[u][i-1]));
else t[f].insert(dis[u][i-1]);
if(max1(f)!=-1) s[ff].insert(max1(f));
if(max2(ff)!=-1) ans.insert(max2(ff));
}
}
}
int main(){
int i,n,a,b,m;
char c[10];
scanf("%d",&n);
for(i=1;i<n;i++){
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
root=0;
dfsroot(1,0,n);
dfs(root);
for(i=1;i<=n;i++){
update(i);
A[i]=1;
}
scanf("%d",&m);
while(m--){
scanf("%s",c);
if(c[0]=='C'){
scanf("%d",&a);
update(a);
A[a]=!A[a];
}
else{
if(ans.size()==0) printf("-1\n");
else{
it=ans.end();
it--;
printf("%d\n",*it);
}
}
}
return 0;
}