題目鏈接:http://acm.hdu.edu.cn/showproblem.php?pid=4670
算法:基於點的樹鏈分治
思路:
//樹基於點的分治算法,可以參見國家集訓隊論文:2009年漆子超《分治算法在樹的路徑問題中的應用》
/*
000111222
+ 012012012
----------------------
= 012120201
*/
//Mul[u,v] 等於點u到點v的乘積中,各個素數的個數對3取餘後的一個數列
//對於每一個Mul[u, S], 求Mul(S, v]+Mul[u,S]=0的個數(S爲分治點)
//用map記錄下所有Mul[S,u]的狀態(用longlong保存3進制數),再用一個map保存取出分治點S以後(S,v]的狀態,對於第一個map中的每一個狀態,求在第二個map中是否存在它的互補狀態
注意:
//第一次遇到手動擴棧的題,檢查了好長時間
//素數有可能超過int
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<map>
using namespace std;
#define LL __int64
const int M = 50010;
LL pri[35];//保存素數
int Num;//素數的個數
int num[M][30];//每個數的化簡
struct Edge {
int v, next;
} edge[M << 1];
int head[M], E;//鄰接表參數
LL ans = 0;//保存最終結果
int mmax, root, size[M], mx[M], vis[M];//求重點的參數
int tmp_cnt;//計算類似8,27,125這樣單獨成區間的區間個數
int dis[M][30];//保存區間[S,i]中第j個素數的個數dis[i][j]
map<LL, LL> er, mp;//er是計算3進制下需要匹配的數的個數,mp是計算分治點S到目標點Q,Mul[S,Q]在3進制下得到的數
map<LL, LL>::iterator it;
void init() {
ans = E = 0;
memset(vis, 0, sizeof(vis));
memset(head, -1, sizeof(head));
memset(num, 0, sizeof(num));
}
void add_edge(int s, int v) {
edge[E].v = v;
edge[E].next = head[s];
head[s] = E++;
}
void dfs_uu(int u, int fa) {
size[u] = 1;
mx[u] = 0;
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (v == fa || vis[v])continue;
dfs_uu(v, u);
size[u] += size[v];
if (size[v] > mx[u]) mx[u] = size[v];
}
}
void dfs_u(int u, int fa, int r) {
if (size[r] - size[u] > mx[u]) mx[u] = size[r] - size[u];
if (mmax > mx[u]) { mmax = mx[u]; root = u;}
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (v == fa || vis[v]) continue;
dfs_u(v, u, r);
}
}//以上兩個函數求重點
void dfs_dis(int u, int fa, int r) {
LL tmp = 0, sum = 0, sum_nee = 0;
for (int i = 0; i < Num; i++) {
dis[u][i] = dis[fa][i] + num[u][i];
if (dis[u][i] > 2) dis[u][i] -= 3;
tmp = 3 - dis[u][i] + dis[r][i]; if (tmp >= 3) tmp -= 3;
sum = sum * 3 + tmp;
sum_nee = sum_nee * 3 + dis[u][i];
}
if (sum == sum_nee) tmp_cnt++;
er[sum]++;
mp[sum_nee]++;
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (v == fa || vis[v])
continue;
dfs_dis(v, u, r);
}
}
LL cala(int u, int fa) {
mp.clear(); er.clear();
LL tt = 0;
tmp_cnt = 0;
if (fa != 0) for (int i = 0; i < Num; i++) dis[fa][i] = num[fa][i];
dfs_dis(u, fa, fa == 0 ? u : fa);
for (it = er.begin(); it != er.end(); it++) {
if (mp.find(it->first) != mp.end())
tt += it->second * mp[it->first];
}
return (tt - tmp_cnt) / 2 + tmp_cnt;
}//以上兩個【核心】函數求分治區間中滿足條件的點對~
void solve(int u) {
int rt = 0;
mmax = 123456;
root = u;
dfs_uu(u, -1);
dfs_u(u, -1, u);
ans += cala(root, 0);
vis[root] = 1;
rt = root;
for (int i = head[rt]; i != -1; i = edge[i].next) {
int v = edge[i].v;
if (vis[v]) continue;
ans -= cala(v, rt);
solve(v);
}
}
int main() {
int n;
while (scanf("%d", &n) != EOF) {
int i, j;
LL val;
init();
scanf("%d", &Num);
for (i = 0; i < Num; i++) {
scanf("%I64d", &pri[i]);
}
for (i = 1; i <= n; i++) {
scanf("%I64d", &val);
for (j = 0; j < Num; j++) {
while (val % pri[j] == 0) {
num[i][j]++;
val = val / pri[j];
if (num[i][j] > 2)
num[i][j] -= 3;
}
}
}
int a, b;
for (i = 0; i < n - 1; i++) {
scanf("%d%d", &a, &b);
add_edge(a, b);
add_edge(b, a);
}
solve(1);
printf("%I64d\n", ans);
}
return 0;
}