SPOJ COT(主席樹 + LCA)

題目鏈接:https://www.spoj.com/problems/COT/

主席樹(可持久化線段樹)簡單來說就是「線段樹 + 前綴和」:每個節點 x 的線段樹是在其父節點的版本上插入 a[x] 得到的,所以它存的是從根到 x 這條路徑上所有點權的計數。現在詢問 u, v,設 u, v 的 LCA 是 root,root 的父節點是 fa[root],那麼路徑 u–v 上落在左半值域的點數就是 tVal[ lt[ t[u] ] ] + tVal[ lt[ t[v] ] ] - tVal[ lt[ t[root] ] ] - tVal[ lt[ t[fa[root]] ] ]。
一開始我把 root 的值減了兩次,但 u 到 v 的路徑本身包含 root,這樣 root 這個點就被完全扣掉了,結果少算了 root 的值。正確做法是減一次 root、再減一次 fa[root]。以後寫 LCA 相關的路徑差分時要注意。

#include <bits/stdc++.h>

using namespace std;

const int maxn = 2*(int)1e6+100;
// n: vertex count, m: query count, a[i]: weight of vertex i,
// b[1..len]: the distinct weights in sorted order (used for rank compression).
int n, m, a[maxn], b[maxn], len;
// Persistent segment tree node pool: tVal[id] = number of inserted values in the
// node's range, lt/rt[id] = child ids, tot = last id handed out.
// t[x] = root id of the version belonging to vertex x (t[0] = the empty tree). It is
// indexed by vertex, not by pool node, so maxn entries suffice (was maxn*40).
// fa[x] = parent of x in the tree explored from vertex 1 (fa[1] = 0).
int tot = 0, tVal[maxn*40], t[maxn], lt[maxn*40], rt[maxn*40], fa[maxn];
// Adjacency lists used by the free dfs() that builds the persistent versions.
vector< vector<int> > G(maxn);
struct LCA
{
	#define type int
	struct node{int to;type w;node(){}node(int _to,type _w):to(_to),w(_w){}};
	type dist[maxn];
	int path[maxn],dep[maxn],loc[maxn],len[maxn],LOG[maxn],all,n;
	int dp[25][maxn], point[25][maxn];  //2^20 == 1e6  2^25 == 3e7
	vector<node> G[maxn];
	void dfs(int u, int now) {
		path[++all] = u;
		loc[u] = all;
		dep[all] = now;
		for (node cur : G[u]) {
            int v = cur.to;
			if (loc[v]) continue;
			len[v] = now+1;
			dist[v] = dist[u]+cur.w;
			dfs(v, now+1);
			path[++all] = u;
			dep[all] = now;
		}
	}
	void initRMQ(int n)
	{
		LOG[0] = -1;
		for (int i = 1; i <= all; ++i) {
			dp[0][i] = dep[i];
			point[0][i] = path[i];
			LOG[i] = ((i&(i-1)) == 0 ? LOG[i-1]+1 : LOG[i-1]);
		}
        for (int i = 1; (1<<i) <= all; ++i) {
            for (int j = 1; j+(1<<i)-1 <= all; ++j) {
              if (dp[i-1][j] < dp[i-1][j+(1<<(i-1))]) {
                dp[i][j] = dp[i-1][j];
                point[i][j] = point[i-1][j];
              } else {
                dp[i][j] = dp[i-1][j+(1<<(i-1))];
                point[i][j] = point[i-1][j+(1<<(i-1))];
              }
            }
        }
	}
	int queryLCA(int l,int r)
	{
		l = loc[l]; r = loc[r];
		if(l>r) swap(l,r);
		int k = LOG[r-l+1];
		/*
		貌似下面這種寫法對於某些數據情況更快,對於某些數據也更慢- -
		記得把上面預處理的LOG刪了
		 P 3379
		int k=0;
		while((1<<k)<=r-l+1) k++;
		k--;
		*/
		if(dp[k][l] < dp[k][r-(1<<k)+1]) return point[k][l];
		else return point[k][r-(1<<k)+1];
	}

	type getDist(int a,int b){return dist[a]+dist[b]-2*dist[queryLCA(a,b)];}
	int getLen(int a,int b){return len[a]+len[b]-2*len[queryLCA(a,b)];}
	void init(int _n)
	{
		n = _n;
		all = 0;
		for(int i = 0;i <= n; i++)
		{
			loc[i] = 0;
			dist[i] = 0;
			len[i] = 0;
			G[i].clear();
		}
	}
	void addEdge(int a,int b,type w=1)
	{
		G[a].emplace_back(node(b,w));
		G[b].emplace_back(node(a,w));
	}
	void solve(int root)
	{
		dfs(root, 1);
		initRMQ(all);
	}
	#undef type
}lca;

// Allocates an all-zero segment tree over value ranks [l, r] and returns its root id.
// Ids come from the global counter `tot` in pre-order (node, then left, then right);
// tVal is not written, so the counts stay at their zero-initialized value.
int build(int l, int r) {
    const int self = ++tot;
    if (!(l < r)) return self;          // single rank (or empty range): no children
    const int half = (l + r) >> 1;
    lt[self] = build(l, half);
    rt[self] = build(half + 1, r);
    return self;
}
// Returns the root of a new version equal to version `par` plus one occurrence of
// rank p. Only the nodes on the path from the root to leaf p are copied, and each
// copy's count is incremented. All other subtrees are shared with `par`. New ids are
// taken from `tot` top-down (root first), walking the path with a loop.
int update(int l, int r, int par, int p) {
    const int root = ++tot;
    int cur = root;   // freshly allocated node being filled in
    int old = par;    // corresponding node in the previous version
    for (;;) {
        lt[cur] = lt[old];
        rt[cur] = rt[old];
        tVal[cur] = tVal[old] + 1;
        if (l >= r) break;              // reached the leaf for p
        const int mid = (l + r) >> 1;
        const int next = ++tot;
        if (p <= mid) {
            lt[cur] = next;
            old = lt[old];
            r = mid;
        } else {
            rt[cur] = next;
            old = rt[old];
            l = mid + 1;
        }
        cur = next;
    }
    return root;
}
// Finds the k-th smallest rank on the path u..v. ql/qr are the versions of u and v,
// par is the version of their LCA and parpar that of the LCA's parent. At each level
// the path's count in the left half is ql + qr - par - parpar: the LCA is subtracted
// once and its parent once, so the LCA itself is still counted once.
// Descends iteratively and returns the rank where the range collapses.
int query(int l, int r, int parpar, int par, int ql, int qr, int k) {
    while (l != r) {
        const int mid = (l + r) >> 1;
        const int leftCount = tVal[lt[ql]] + tVal[lt[qr]] - tVal[lt[par]] - tVal[lt[parpar]];
        if (k <= leftCount) {
            parpar = lt[parpar];
            par = lt[par];
            ql = lt[ql];
            qr = lt[qr];
            r = mid;
        } else {
            k -= leftCount;             // skip everything in the left half
            parpar = rt[parpar];
            par = rt[par];
            ql = rt[ql];
            qr = rt[qr];
            l = mid + 1;
        }
    }
    return l;
}
void dfs(int x, int par) {
    fa[x] = par;
    int p = lower_bound(b + 1, b + 1 + len, a[x]) - b;
    t[x] = update(1, len, t[par], p);
    for (int to : G[x]) {
        if (to == par) continue;
        dfs(to, x);
    }
}
// Reads the next (optionally negative) decimal integer from stdin via getchar().
// Any non-digit characters before it are skipped; a '-' among them makes the
// result negative. Returns 0 if end of input is reached before any digit, instead
// of spinning forever on EOF.
int read() {
    int ans = 0, f = 1;
    int c = getchar();   // int, not char: EOF must stay distinct from every byte value
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}

// Reads the tree and queries, builds the LCA structure and the per-vertex persistent
// segment trees, then prints the k-th smallest weight on each queried u..v path.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
#ifdef LOCAL_DEFINE
    freopen("input.txt", "r", stdin);
#endif

    // tVal/lt/rt/t have static storage and are already zero-initialized. An explicit
    // memset over tVal (maxn*40 ints) would only force hundreds of MB to be committed.
    n = read(); m = read();
    lca.init(n);
    for (int i = 1; i <= n; ++i) {
        a[i] = read();
        b[i] = a[i];
    }
    // Coordinate-compress the weights: b[1..len] holds the distinct values in order.
    sort(b + 1, b + 1 + n);
    len = unique(b + 1, b + 1 + n) - b - 1;
    for (int i = 1; i <= n - 1; ++i) {
        int u, v;
        u = read(); v = read();
        G[u].emplace_back(v);
        G[v].emplace_back(u);
        lca.addEdge(u, v, 1);
    }
    lca.solve(1);
    t[0] = build(1, len);   // empty version used as the parent of vertex 1
    dfs(1, 0);
    for (int i = 1; i <= m; ++i) {
        int u, v, k;
        u = read(); v = read(); k = read();
        int temp = lca.queryLCA(u, v);
        // Path u..v = (root..u) + (root..v) - (root..temp) - (root..fa[temp]).
        int pos = query(1, len, t[fa[temp]], t[temp], t[u], t[v], k);
        cout << b[pos] << '\n';
    }

#ifdef LOCAL_DEFINE
    cerr << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << " s.\n";
#endif
    return 0;
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章