題面:
有一棵點數為 $10^5$ 的樹,每個節點有一個權值,權值範圍是 $[1,10^5]$,問所有權值互質的兩個節點之間距離的和。
關於虛樹,記下一些key point以後要是忘了可以回來看看
虛樹是只保留原樹中所選定的一些節點,以及這些選定節點兩兩之間的 LCA(最近公共祖先),並按原樹中的祖先關係連邊而得到的樹。
可以保證如果選定節點有 $k$ 個,虛樹的節點數不超過 $2k-1$。
建立虛樹的方法是將選定節點按照dfs序排序之後,維護一個深度單調的棧,棧中保存的是可能和下一個節點產生父子關係的點。這些點滿足在同一條路徑上。
每次加入一個點的時候,將這個點和棧頂的點(如果棧空的話將這個點直接入棧就可以了)的lca計算出來,找到lca在棧中維護的路徑的鏈的位置去更新棧。具體見代碼。
得到虛樹之後,每條邊的貢獻就是這條邊的長度(原樹中兩端的深度差)乘以這條邊兩邊被標記的點的個數的乘積,累加就好了;最後再以莫比烏斯函數 $\mu(k)$ 為權,把各個 $k$ 的結果加起來,就得到權值互質的點對的答案。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
// Array capacities and binary-lifting height (2^17 > 1e5 covers any depth).
// maxm is not referenced in the visible code.
const int maxn = 212345;
const int maxm = 212345;
const int max_log = 18;
// Root of the original tree.
const int ROOT = 1;
// Virtual-tree adjacency on compact ids: (neighbour, edge length) pairs.
vector<pair<int,LL> > vedge[maxn];
// Original-tree adjacency on 1-based node ids.
vector<int> edge[maxn];
// siz[u]: number of marked virtual nodes in the subtree of u (filled by dffs).
LL siz[maxn];
// is[u]: 1 when virtual node u is marked, 0 otherwise.
int is[maxn];
// Post-order pass over the virtual tree. Sets siz[st] to the number of marked
// nodes (is[] == 1) in the subtree of st, with the tree rooted away from fa.
void dffs(int st,int fa){
    siz[st] = is[st];
    for(size_t idx = 0; idx < vedge[st].size(); ++idx){
        const int child = vedge[st][idx].first;
        if(child == fa) continue;
        dffs(child, st);
        siz[st] += siz[child];
    }
}
// Sums the contribution of every virtual-tree edge below st. For an edge
// (st, child) of length len, the contribution is len * siz[child] * (n - siz[child]).
// That is, len times the number of marked pairs split by the edge, where n is
// the total number of marked nodes.
LL dfss(int st,int fa,int n){
    LL acc = 0;
    for(const auto& e : vedge[st]){
        const int child = e.first;
        if(child == fa) continue;
        const LL len = e.second;
        acc += dfss(child, st, n);
        acc += len * siz[child] * (n - siz[child]);
    }
    return acc;
}
// Evaluates one virtual tree rooted at compact id 0. It first counts marked
// nodes per subtree, then sums the edge contributions.
// The parameter n is unused; the total marked count is read from siz[0].
LL dpfs(int n){
    (void)n;
    dffs(0, -1);
    const int marked = siz[0];
    return dfss(0, -1, marked);
}
// dfn[u]: DFS entry index, counting from _cnt's value when dfs starts.
// _cnt: next DFS index to hand out.
int dfn[maxn];
int _cnt;
// deep[u]: depth of u (the starting node gets Deep, default 1).
// fa[u][j]: 2^j-th ancestor of u, or -1 when it does not exist.
int deep[maxn];
int fa[maxn][max_log];
// Recursive DFS over the original tree. Records dfn, depth and direct parent,
// and resets the higher ancestor slots to -1.
void dfs(int st,int Fa,int Deep=1){
    dfn[st] = _cnt++;
    fa[st][0] = Fa;
    deep[st] = Deep;
    for(int j = 1; j < max_log; ++j) fa[st][j] = -1;
    for(const int nxt : edge[st]){
        if(nxt != Fa) dfs(nxt, st, Deep + 1);
    }
}
// Prepares LCA queries on the original tree with nodes 1..n. It roots the tree
// at ROOT via dfs, with DFS indices starting at 1. It then fills the
// binary-lifting table fa[i][j] = fa[fa[i][j-1]][j-1].
void init(int n){
    memset(fa, -1, sizeof(fa));
    _cnt = 1;
    dfs(ROOT, -1);
    for(int j = 1; j < max_log; ++j){
        for(int i = 1; i <= n; ++i){
            const int mid = fa[i][j - 1];
            if(mid == -1) continue;
            fa[i][j] = fa[mid][j - 1];
        }
    }
}
// Lowest common ancestor of original-tree nodes x and y via binary lifting.
// Requires init() to have filled deep[] and fa[][]; fa[v][i] == -1 means the
// 2^i-th ancestor of v does not exist.
int Lca(int x,int y){
    if (deep[x] < deep[y]) swap(x,y);
    // Lift x up to y's depth. A missing ancestor (-1) must be skipped before
    // indexing deep[]; reading deep[-1] is out of bounds.
    for (int i = max_log-1;i >= 0;i--){
        const int up = fa[x][i];
        if (up != -1 && deep[up] >= deep[y]){
            x = up;
        }
    }
    if (x == y) return x;
    // Both nodes are now at equal depth, so fa[x][i] and fa[y][i] are either
    // both -1 or both valid.
    for (int i = max_log-1;i >= 0;i--){
        if (fa[x][i] != fa[y][i]){
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    return fa[x][0];
}
// val[u]: weight of original node u.
// cid[id]: original node behind compact virtual-tree id `id`.
// sid[u]: compact id of original node u, or -1 when none is assigned.
int val[maxn];
int cid[maxn];
int sid[maxn];
// Number of compact ids handed out for the current virtual tree.
int cnt;
// Returns the compact id of original node v. On first use it assigns the next
// free id and records the reverse mapping.
int tcid(int v){
    if(sid[v] == -1){
        sid[v] = cnt;
        cid[cnt] = v;
        ++cnt;
    }
    return sid[v];
}
// Adds an undirected virtual-tree edge between original nodes st and ed.
// When v is omitted (-1), the length is |deep[st] - deep[ed]|. That equals the
// path length when one node is an ancestor of the other, which is how the
// stack calls it. Zero-length edges are dropped without assigning compact ids.
void Link(int st,int ed,int v = -1){
    const int len = (v == -1) ? abs(deep[st] - deep[ed]) : v;
    if(len == 0) return;
    const int a = tcid(st);
    const int b = tcid(ed);
    vedge[a].push_back(make_pair(b, len));
    vedge[b].push_back(make_pair(a, len));
}
// Depth-monotone stack used to build the virtual tree.
// `last` holds the most recently popped node (-1 if none is pending):
//  - pop() links that pending node to the node being popped now;
//  - push(x) links it to x.
// Either call then updates or clears `last`.
struct Stack{
    int sta[maxn],_cnt;
    int last;
    void init(){
        _cnt = 0;
        last = -1;
    }
    bool empty(){ return !_cnt; }
    int top(){ return sta[_cnt - 1]; }
    void push(int x){
        if(last != -1) Link(last, x);
        last = -1;
        sta[_cnt] = x;
        ++_cnt;
    }
    void pop(){
        const int cur = top();
        if(last != -1) Link(cur, last);
        last = cur;
        --_cnt;
    }
}S;
// hav[k]: original nodes whose value is divisible by k. main only fills it for
// k with mu[k] != 0.
vector<int> hav[maxn];
// Builds the virtual tree over hav[k], plus the LCAs the stack inserts. Returns
// the sum of tree distances over all unordered pairs of virtual-tree nodes
// whose value is divisible by k.
// Side effects: sorts hav[k] in place by dfn. Resets vedge/sid/cid for the
// compact ids it used before returning.
LL bvtree(int k){
// The stack construction needs nodes in DFS entry order.
sort(hav[k].begin(),hav[k].end(),[](int x,int y){return dfn[x] < dfn[y];});
cnt = 0;
S.init();
for(auto x : hav[k]){
if(S.empty()){
S.push(x);
continue;
}
// Branch point between the chain on the stack (ending at its top) and x.
int lca = Lca(S.top(),x);
// Pop every chain node at or below the LCA's depth. Stack::pop links each
// popped node to the next one popped; the last one is linked by the push below.
while(S.empty()==false && deep[S.top()] >= deep[lca])
S.pop();
// After the loop this always holds, so the LCA is (re)pushed. If it had just
// been popped, Link ignores the resulting zero-length self link.
if(S.empty() || deep[S.top()] < deep[lca])
S.push(lca);
// When x is the LCA it is already on the stack.
if(lca != x)
S.push(x);
}
// Pop the remaining chain so its nodes get linked too.
while(S.empty()==false) S.pop();
// Mark virtual nodes (LCAs included) whose value is divisible by k.
for(int i=0;i<cnt;i++)
is[i] = val[cid[i]] % k == 0;
// If nothing was linked (a single node), cnt stays 0 and dpfs runs on compact
// id 0. Its adjacency list should be empty (main and earlier calls clear it),
// so the result is 0.
LL ret = dpfs(cnt);
// Reset per-call state for the next k.
for(int i=0;i<cnt;i++){
vedge[i].clear();
sid[cid[i]] = -1;
cid[i] = -1;
}
return ret;
}
// fac[j]: every divisor of j. mu[j]: Moebius function of j.
vector<int> fac[maxn];
int mu[maxn];
// Reads test cases until EOF. Each test case is n, then n node values, then
// n-1 edges. For each case it prints the sum, over unordered node pairs with
// coprime values, of their tree distance. Moebius inversion gives that sum as
// sum_k mu(k) * (distance sum over pairs whose values are both divisible by k).
int main(){
    int n;
    for(int i=1;i<maxn;i++){
        for(int j=i;j<maxn;j+=i)
            fac[j].push_back(i);
    }
    // Sieve built on sum_{d|n} mu(d) == [n == 1].
    memset(mu,0,sizeof(mu));
    mu[1] = 1;
    for(int i=1;i<maxn;i++){
        for(int j=2*i;j<maxn;j+=i)
            mu[j] -= mu[i];
    }
    while(~scanf("%d",&n)){
        for(int i=0;i<=n;i++)
            edge[i].clear(),hav[i].clear(),vedge[i].clear();
        int v;
        for(int i=1;i<=n;i++){
            scanf("%d",&v);
            val[i] = v;
            // Only divisors with mu != 0 contribute to the inversion.
            for(int x : fac[v]){
                if(mu[x]) hav[x].push_back(i);
            }
        }
        int l,r;
        for(int i=1;i<n;i++){
            scanf("%d %d",&l,&r);
            edge[l].push_back(r);
            edge[r].push_back(l);
        }
        init(n);
        memset(sid,-1,sizeof(sid));
        LL ans = 0;
        for(int i=1;i<maxn;i++){
            if(hav[i].empty()) continue;
            ans += mu[i] * bvtree(i);
            // hav is indexed by divisor value, not node id. Clearing hav[0..n]
            // above is therefore not enough: lists for k > n would carry nodes
            // from an earlier test case. Clear each list after use.
            hav[i].clear();
        }
        printf("%lld\n",ans);
    }
    return 0;
}
附上一個比賽時的弱雞代碼
#include<bits/stdc++.h>
using namespace std;
// Array capacity and binary-lifting height.
const int maxn = 112345;
const int max_log = 18;
// Original-tree adjacency on 1-based node ids.
vector<int> edge[maxn];
// Virtual-tree adjacency on compact ids: (neighbour, edge length) pairs.
vector<pair<int,int> > vedge[maxn];
// dis[u]: distance of compact node u from the current pdfs source.
int dis[maxn];
// Propagates distances outward from st over the virtual tree. Given dis[st],
// it sets dis[] for every node reachable without passing through fa.
void pdfs(int st,int fa){
    for(size_t idx = 0; idx < vedge[st].size(); ++idx){
        const int to = vedge[st][idx].first;
        if(to == fa) continue;
        dis[to] = dis[st] + vedge[st][idx].second;
        pdfs(to, st);
    }
}
// Two-pass farthest-node search on the virtual tree with compact ids 0..n-1.
// The first pass finds the node a farthest from st; the second finds the node
// b farthest from a. Returns (a, b). With non-negative edge lengths these are
// the endpoints of a longest path. Ties resolve to the smallest id.
pair<int,int> getPair(int n,int st){
    dis[st] = 0;
    pdfs(st, -1);
    const int a = max_element(dis, dis + n) - dis;
    dis[a] = 0;
    pdfs(a, -1);
    const int b = max_element(dis, dis + n) - dis;
    return make_pair(a, b);
}
// Need check (original author's note on this section).
// Root of the original tree.
const int ROOT = 1;
// deep[u]: depth (the starting node gets Deep, default 1).
// dfn[u]: DFS entry index.
// fa[u][j]: 2^j-th ancestor of u, or -1 when it does not exist.
int deep[maxn];
int dfn[maxn];
int fa[maxn][max_log];
// Next DFS index to hand out.
int _cnt;
// Recursive DFS over the original tree. Records dfn, depth and direct parent,
// and resets the higher ancestor slots to -1.
void dfs(int st,int Fa,int Deep=1){
    dfn[st] = _cnt++;
    deep[st] = Deep;
    fa[st][0] = Fa;
    for(int j = 1; j < max_log; ++j) fa[st][j] = -1;
    for(const int nxt : edge[st]){
        if(nxt != Fa) dfs(nxt, st, Deep + 1);
    }
}
// Prepares LCA queries on the original tree with nodes 1..n. It roots the tree
// at ROOT via dfs, with DFS indices starting at 1. It then fills the
// binary-lifting table fa[i][j] = fa[fa[i][j-1]][j-1].
void init(int n){
    _cnt = 1;
    memset(fa, -1, sizeof(fa));
    dfs(ROOT, -1);
    for(int j = 1; j < max_log; ++j){
        for(int i = 1; i <= n; ++i){
            const int mid = fa[i][j - 1];
            if(mid == -1) continue;
            fa[i][j] = fa[mid][j - 1];
        }
    }
}
// Lowest common ancestor of original-tree nodes x and y via binary lifting.
// Requires init() to have filled deep[] and fa[][]; fa[v][i] == -1 means the
// 2^i-th ancestor of v does not exist.
int Lca(int x,int y){
    if (deep[x] < deep[y]) swap(x,y);
    // Lift x up to y's depth. A missing ancestor (-1) must be skipped before
    // indexing deep[]; reading deep[-1] is out of bounds.
    for (int i = max_log-1;i >= 0;i--){
        const int up = fa[x][i];
        if (up != -1 && deep[up] >= deep[y]){
            x = up;
        }
    }
    if (x == y) return x;
    // Both nodes are now at equal depth, so fa[x][i] and fa[y][i] are either
    // both -1 or both valid.
    for (int i = max_log-1;i >= 0;i--){
        if (fa[x][i] != fa[y][i]){
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    return fa[x][0];
}
// cid[id]: original node behind compact virtual-tree id `id`.
// sid[u]: compact id of original node u; every entry must be -1 before use
// (main memsets it).
int cid[maxn];
int sid[maxn];
// Number of compact ids handed out for the current virtual tree.
int vnt;
// Returns the compact id of original node v. On first use it assigns the next
// free id and records the reverse mapping.
int tcid(int v){
    if(sid[v] == -1){
        sid[v] = vnt;
        cid[vnt] = v;
        ++vnt;
    }
    return sid[v];
}
// Adds an undirected virtual-tree edge st--ed of length |deep[st] - deep[ed]|.
// No edge is added when ed is -1 (no pending node) or when st == ed.
void Link(int st,int ed){
    if(ed != -1 && st != ed){
        const int len = abs(deep[st] - deep[ed]);
        const int a = tcid(st);
        const int b = tcid(ed);
        vedge[a].push_back(make_pair(b, len));
        vedge[b].push_back(make_pair(a, len));
    }
}
// Depth-monotone stack holding the current root chain while building the virtual tree.
stack<int> S;
// Input nodes (one name group) for bvtree; sorted in place by it.
vector<int> inp;
// Builds the virtual tree over inp, plus the LCAs needed to connect them.
// Returns the two endpoints (original node ids) of a longest path in it, as
// found by getPair. A single-node input returns that node twice.
// `last` is the most recently popped node. It is linked to its virtual parent
// by the next pop (the new top) or push (the LCA).
pair<int,int> bvtree(){
if(inp.size() ==1) return make_pair(inp[0],inp[0]);
// The stack construction needs nodes in DFS entry order.
sort(inp.begin(),inp.end(),[](int x,int y){return dfn[x] < dfn[y];});
vnt = 0;
int last = -1;
while(S.empty()==false) S.pop();
for(auto x : inp){
// Only reached for the first node; last is -1 then, so Link is a no-op.
if(S.empty()){
Link(x,last) ,last = -1 ,S.push(x) ;
continue;
}
// Branch point between the chain on the stack and x.
int lca = Lca(S.top(),x);
// Pop chain nodes at or below the LCA's depth, linking each to the node popped before it.
while(!S.empty() && deep[S.top()] >= deep[lca])
Link(S.top(),last), last = S.top(), S.pop();
// Re-insert the LCA under the last popped node; Link ignores the self link when
// the LCA itself was just popped.
if( S.empty() || deep[S.top()] < deep[lca])
Link(lca ,last), last = -1 ,S.push(lca);
// When x is the LCA it is already on the stack.
if(lca != x) S.push(x) ;
}
// Pop the remaining chain so its nodes get linked too.
while(!S.empty())
Link(S.top(),last), last = S.top(), S.pop();
// getPair works on compact ids; map the endpoints back to original nodes.
auto ret = getPair(vnt,0);
ret.first = cid[ret.first] , ret.second = cid[ret.second];
// Reset per-call state (vedge, sid, cid) for the next group.
for(int i=0;i<vnt;i++)
vedge[i].clear(), sid[cid[i]] = -1, cid[i] = -1;
return ret;
}
map<string,vector<int> >store;
map<string,pair<int,int> >lis;
int dist(int x,int y){
return deep[x] + deep[y] - deep[Lca(x,y)] * 2 + 1;
}
char name[10];
int query(){
scanf("%s",name);
if(lis.find(name) == lis.end()){
scanf("%s",name);
return -1;
}
auto a = lis[name];
scanf("%s",name);
if(lis.find(name) == lis.end()){
return -1;
}
auto b = lis[name];
return max( max(dist(a.first,b.first) , dist(a.first,b.second)),
max(dist(a.second,b.first), dist(a.second,b.second)));
}
// Reads test cases until EOF. Each test case is n and m, then a name for each
// of the n nodes, then n-1 edges, then m queries. It precomputes a far-apart
// endpoint pair per name, then prints one answer per query.
int main(){
    int n,m;
    while(~scanf("%d %d",&n,&m)){
        store.clear(),lis.clear();
        for(int i=1;i<=n;i++){
            scanf("%s",name);
            store[name].push_back(i);
            edge[i].clear();
        }
        int l,r;
        for(int i=1;i<n;i++){
            scanf("%d %d",&l,&r);
            edge[l].push_back(r);
            edge[r].push_back(l);
        }
        init(n);
        memset(sid,-1,sizeof(sid));
        // Bind by reference: iterating by value copied every (name, node list)
        // entry of the map.
        for(const auto& it : store){
            inp = it.second;  // bvtree sorts inp in place, so it gets its own copy
            lis[it.first] = bvtree();
        }
        while(m--){
            printf("%d\n",query());
        }
    }
    return 0;
}