問題描述
輸入格式
第一行包含兩個整數 N,Q,表示星球的數量和操作的數量。星球從 1 開始編號。
接下來的 Q 行,每行是如下兩種格式之一:
A x y 表示在 x和 y之間連一條邊。保證之前 x 和 y 是不聯通的。
Q x y 表示詢問 (x,y)這條邊上的負載。保證 x和 y之間有一條邊。
輸出格式
對每個查詢操作,輸出被查詢的邊的負載。
樣例輸入1
8 6
A 2 3
A 3 4
A 3 8
A 8 7
A 6 5
Q 3 8
樣例輸出1
6
樣例輸入2
10 20
A 4 10
A 1 6
A 10 1
Q 4 10
A 9 8
A 7 9
Q 10 1
A 2 5
Q 9 8
Q 2 5
Q 4 10
Q 9 8
Q 2 5
A 8 3
Q 7 9
Q 2 5
A 5 7
Q 8 3
A 3 4
Q 7 9
樣例輸出2
3
4
2
1
3
2
1
3
1
5
21
數據範圍
對於所有數據,1≤N,Q≤100000
題解
顯然,根據乘法原理,一條邊的“負載”等於邊兩端的連通點數的乘積。如樣例一,【3,8】中,與3號點連通的點(在不走【3,8】能到達的點,包括它本身)有3個,與8號點連通的點有2個,所以“負載”是6 。
所以,求“負載”的時候,我們可以先把一端點x弄成根,設以x爲根的子樹節點數爲Size[x],以y爲根的子樹節點數爲Size[y],答案即爲Size[x]*(Size[x]-Size[y])
維護動態樹,LCT吧。。。
但是注意!當在維護Splay的時候,不要只記錄Splay樹中的節點個數。我是直接開了個v[]數組,記錄以i 點爲根且到i 點路徑不經過i 點的preferred son的點數。但這個維護起來又有點噁心。。。
還有,1e5 顯然是會爆int 的,注意使用long long。
代碼
#include <cstdio>
#include <iostream>
#include <ctime>
#include <stack>
#include <cstdlib>
#include <algorithm>
#define ll long long
using namespace std;
const ll Q=100005;
ll ls[Q],rs[Q],si[Q],n,f[Q],an[Q],v[Q],lazy[Q];
void lx(ll x)
{
ll y=f[x],z=f[y];
if(z)if(ls[z]==y)ls[z]=x;
else rs[z]=x;
f[x]=z;
swap(an[x],an[y]);
rs[y]=ls[x];
f[rs[y]]=y;
f[y]=x;
ls[x]=y;
si[x]=si[y];
si[y]=si[ls[y]]+si[rs[y]]+v[y];
}
void rx(ll x)
{
ll y=f[x],z=f[y];
if(z)if(ls[z]==y)ls[z]=x;
else rs[z]=x;
f[x]=z;
swap(an[x],an[y]);
ls[y]=rs[x];
f[ls[y]]=y;
f[y]=x;
rs[x]=y;
si[x]=si[y];
si[y]=si[ls[y]]+si[rs[y]]+v[y];
}
void pd(ll x)
{
swap(ls[x],rs[x]);
if(ls[x])lazy[ls[x]]^=1;
if(rs[x])lazy[rs[x]]^=1;
lazy[x]=0;
}
int ding=0,st[Q];
void splay(ll x)
{
for(ll now=x;now;now=f[now])st[++ding]=now;
while(ding)
{
if(lazy[st[ding]])pd(st[ding]);
--ding;
}
while(f[x])
{
ll y=f[x],z=f[y];
if(z)if(ls[z]==y)
if(ls[y]==x)rx(y),rx(x);
else lx(x),rx(x);
else if(rs[y]==x)lx(y),lx(x);
else rx(x),lx(x);
else if(ls[y]==x)rx(x);
else lx(x);
}
}
void ac(ll x)
{
ll y=0;
while(x)
{
splay(x);
if(rs[x])
{
v[x]+=si[rs[x]];
f[rs[x]]=0;
an[rs[x]]=x;
}
rs[x]=y;
v[x]-=si[y];
if(y)f[y]=x;
y=x;
x=an[x];
}
}
void mr(ll x)
{
ac(x);
splay(x);
lazy[x]^=1;
}
int main()
{
char o[15];
ll i,x,y;
scanf("%lld%lld",&n,&i);
for(x=1;x<=n;x++)
si[x]=v[x]=1;
while(i--)
{
scanf("%s%lld%lld",o,&x,&y);
if(o[0]=='A'){
mr(x);
an[x]=y;
ll temp=si[x];
while(an[x])
{
splay(an[x]);
si[an[x]]+=temp;
v[an[x]]+=temp;
x=an[x];
}
ac(y);
}
else{
mr(x);
ac(y);
splay(y);
printf("%lld\n",si[x]*(si[y]-si[x]));
}
}
return 0;
}