题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4605
题目大意:有一棵n个节点的有根树。每个节点都有一个权值w,同时树中的节点,要么没有儿子节点,要么就刚好有两个儿子节点。现在要在根节点放一个权值为X的球。当球落到节点 i 的时候,球都会有如下三种情况:
1、如果X=w[i]或者 i 没有儿子节点的话,球就会停在这个节点,不继续往下落;
2、如果X < w[i] 的话,球有1/2的概率落到左儿子,有1/2的概率落到右儿子;
3、如果X > w[i] 的话,球有1/8的概率落到左儿子,有7/8的概率落到右儿子。
现在有q次询问,每次给出一个v和X,表示询问在根节点放置一个权值为X的球的话,有多大的概率落到节点v。
以的形式输出x和y,如果无法到达就直接输出0。
题目思路:对于每次查询,我们可以知道,只与根节点到节点v这条路径上的信息有关。
如果在这条路径上有一个非 v 的节点的权值等于X的话,那么就无法到达节点v了。
现在我们假设在这条路径上,比X大同时是走左儿子的节点个数为MAX_L,比X大同时是走右儿子的节点个数为MAX_R,比X小同时是走左儿子的节点个数为MIN_L,比X小同时是走右儿子的节点个数为MIN_R。
那么最终的答案x = MAX_R,y = MIN_L + MIN_R + MAX_L * 3 + MAX_R * 3。
现在我们来考虑要如何维护这些信息。由于每次的查询只和一条链上的信息有关,所以我们可以考虑将所有的查询都离线下来。只做一遍dfs来维护这些信息,因为在从根节点开始的dfs的过程中,遍历到节点u的话,正好是维护了从根节点到节点u的所有信息。
而我们正好可以用两个树状数组来维护这些信息,一棵维护走左儿子的节点信息,一棵维护走右儿子的节点信息。树状数组中维护的是权值为 i 的节点的个数,那么就可以通过区间查询知道,权值大于 X 的节点有多少个了。(由于权值最大有1e9,所以还得离散化一下)。
这样就可以通过离线操作求出最终的结果了。
具体实现看代码:
#include <bits/stdc++.h>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define all(v) v.begin(),v.end()
#define IOS ios::sync_with_stdio(false)
#define FIN freopen("in.txt","r",stdin)
#define fuck(x) cout<<"["<<#x<<" "<<x<<"]\n"
using namespace std;
typedef long long ll;
typedef pair<int, int>pii;
typedef pair<ll, ll>pll;
const int MX = 1e5 + 7;
const int mod = 998244353;
int n, m, q, N, _;
int val[MX];
int ls[MX], rs[MX];
vector<int> has;
vector<pii>que[MX];
pii ans[MX];
int fa[MX], root;
int get_id(int x) {
return lower_bound(all(has), x) - has.begin() + 1;
}
struct BIT {
int n;
int a[MX * 2];
int lowbit(int x) {return x & -x;}
void init(int _n) {
n = _n;
for (int i = 0; i <= n; i++) a[i] = 0;
}
void add(int x, int d) {
for (; x <= n; x += lowbit(x))
a[x] += d;
}
int sum(int x) {
int res = 0;
for (; x; x -= lowbit(x))
res += a[x];
return res;
}
} TL, TR;
void dfs(int u) {
for (auto now : que[u]) {
if (u == root) {
ans[now.se] = MP(0, 0);
continue;
}
int w = get_id(now.fi);
int a = TL.sum(N), b = TL.sum(w), c = TL.sum(w - 1);
a -= b; b -= c;
int a2 = TR.sum(N), b2 = TR.sum(w), c2 = TR.sum(w - 1);
a2 -= b2; b2 -= c2;
if (b || b2) ans[now.se] = MP(-1, -1);
else ans[now.se] = MP(c2, a + c * 3 + a2 + c2 * 3);
}
if (ls[u]) {
int w = get_id(val[u]);
TL.add(w, 1);
dfs(ls[u]);
TL.add(w, -1);
TR.add(w, 1);
dfs(rs[u]);
TR.add(w, -1);
}
}
int main() {
// FIN;
for (scanf("%d", &_); _; _--) {
scanf("%d", &n);
has.clear();
for (int i = 1; i <= n; i++) {
scanf("%d", &val[i]);
que[i].clear();
fa[i] = ls[i] = rs[i] = 0;
has.pb(val[i]);
}
scanf("%d", &m);
for (int i = 1; i <= m; i++) {
int u, a, b;
scanf("%d%d%d", &u, &a, &b);
fa[a] = fa[b] = u;
ls[u] = a; rs[u] = b;
}
scanf("%d", &q);
for (int i = 1; i <= q; i++) {
int v, X;
scanf("%d%d", &v, &X);
que[v].pb(MP(X, i));
has.pb(X);
}
sort(all(has));
has.erase(unique(all(has)), has.end());
N = has.size();
TL.init(N); TR.init(N);
for (int i = 1; i <= n; i++) {
if (fa[i] == 0) {
root = i;
dfs(i);
break;
}
}
for (int i = 1; i <= q; i++) {
if (ans[i].fi == -1) printf("0\n");
else
printf("%d %d\n", ans[i].fi, ans[i].se);
}
}
return 0;
}