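// SPOJ GSS7 (樹鏈剖分 + 線段樹區間合併: heavy-light decomposition + segment tree with interval merging)
//
// 題目鏈接 (problem link): https://www.spoj.com/problems/GSS7/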


#include <cstdlib>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
#include <iostream>
#include <sstream>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <fstream>
#include <numeric>
#include <iomanip>
#include <bitset>
#include <list>
#include <stdexcept>
#include <functional>
#include <utility>
#include <ctime>

using namespace std;

// Reads the next decimal integer (optionally preceded by '-') from stdin,
// skipping any other characters before it. Returns 0 if input ends before a
// digit is found. The character right after the number is consumed.
// getchar() is kept in an int so that EOF stays distinguishable from a real
// byte (the skip loop therefore terminates at end of input) and so that
// isdigit() always receives a value in its valid domain.
inline int readint() {
        int c = getchar();
        while (c != EOF && c != '-' && !isdigit(c)) c = getchar();
        bool negative = false;
        if (c == '-') {
                negative = true;
                c = getchar();
        }
        int x = 0;
        while (isdigit(c)) {            // isdigit(EOF) is false
                x = x * 10 + (c - '0');
                c = getchar();
        }
        return negative ? -x : x;
}

// Reads the next decimal long long (optionally preceded by '-') from stdin,
// skipping any other characters before it. Returns 0 if input ends before a
// digit is found. The character right after the number is consumed.
// getchar() is kept in an int so that EOF stays distinguishable from a real
// byte (the skip loop therefore terminates at end of input) and so that
// isdigit() always receives a value in its valid domain.
inline long long readlong() {
        int c = getchar();
        while (c != EOF && c != '-' && !isdigit(c)) c = getchar();
        bool negative = false;
        if (c == '-') {
                negative = true;
                c = getchar();
        }
        long long x = 0;
        while (isdigit(c)) {            // isdigit(EOF) is false
                x = x * 10 + (c - '0');
                c = getchar();
        }
        return negative ? -x : x;
}

// Contest shorthand macros. Arguments are parenthesized and no macro ends in a
// semicolon, so each one expands to a single expression or statement that is
// safe inside larger expressions and in unbraced if/else branches.
#define FOR(i, n) for (int i = 0; i < (int)(n); i++)
#define REP(i, a, b) for (int i = (int)(a); i <= (int)(b); i++)
#define CIR(i, a, b) for (int i = (int)(b); i >= (int)(a); i--)
// Walks the forward-star adjacency list of vertex u; i is the arc index.
#define ADJ(i, u) for (int i = hd[u]; i != -1; i = edge[i].nxt)
#define PII pair<int, int>
#define FI first
#define SE second
#define MP make_pair
#define PB push_back
#define SZ(v) ((v).size())
#define ALL(v) (v).begin(), (v).end()
#define CLR(v, a) memset((v), (a), sizeof(v))
#define IT iterator
#define LL long long
#define DB double
#define PI 3.14159265358979323846
// Sentinel: lazy[] == INF means "no pending assignment", and sl[] == -INF
// marks an empty query accumulator.
#define INF 1000000001

// Capacity of the per-vertex arrays (vertices are numbered 1..n).
#define N 100005

// Heavy-light decomposition state, indexed by vertex (1..n) unless noted.
int hd[N];    // first arc of the vertex's adjacency list (-1 = none)
int top[N];   // head vertex of the heavy chain containing the vertex
int son[N];   // heavy child (0 = none)
int sz[N];    // subtree size
int W[N];     // vertex weight
int idx[N];   // position of the vertex in segment-tree order
int dep[N];   // depth below the root
int ord[N];   // indexed by position: ord[idx[v]] == v
int pa[N];    // parent vertex
int E, n, m, id;  // arcs stored, vertex count, operation count, next position
// Path-query accumulators, one slot per side of the path (0 / 1):
// best prefix, best suffix, total, best subsegment.
int sl[2], sr[2], sum[2], sm[2];

// One directed arc of a forward-star adjacency list: `to` is the vertex the arc
// points at, `nxt` the index of the next arc out of the same vertex (-1 ends it).
struct Edge {
        int to;
        int nxt;

        void init(int a, int b) {
                this->nxt = b;
                this->to = a;
        }
};
Edge edge[N << 1];  // two arcs per undirected tree edge

void init() {
        REP(i, 1, n) hd[i] = -1;
        E = 0;
        id = 1;
}

// Adds the undirected edge u-v as two directed arcs (u->v first, then v->u),
// each prepended to the adjacency list of its source vertex.
void addEdge(int u, int v) {
        int ends[2][2] = {{u, v}, {v, u}};
        for (int k = 0; k < 2; ++k) {
                int from = ends[k][0], to = ends[k][1];
                edge[E].init(to, hd[from]);
                hd[from] = E;
                ++E;
        }
}

// First HLD pass over the tree rooted at u (fa is u's parent, -1 for the root).
// Fills pa[] (parent), dep[] (depth, counted from dep[u], which the caller
// provides; the root keeps its zero-initialized value), sz[] (subtree size) and
// son[] (heavy child: the first child in adjacency order with the largest
// subtree, 0 for a leaf).
// Iterative rather than recursive: on a path-shaped tree the recursion depth
// would equal the vertex count (up to ~1e5) and could overflow the call stack.
void dfs1(int u, int fa) {
        std::vector<int> order;         // vertices in DFS preorder
        std::vector<int> stk(1, u);
        pa[u] = fa;
        while (!stk.empty()) {
                int x = stk.back();
                stk.pop_back();
                order.push_back(x);
                ADJ(i, x) {
                        int v = edge[i].to;
                        if (v == pa[x]) continue;
                        pa[v] = x;
                        dep[v] = dep[x] + 1;
                        stk.push_back(v);
                }
        }
        // Reverse preorder finishes every child before its parent.
        for (int k = (int)order.size() - 1; k >= 0; k--) {
                int x = order[k];
                sz[x] = 1, son[x] = 0;
                ADJ(i, x) {
                        int v = edge[i].to;
                        if (v == pa[x]) continue;
                        // sz[0] stays 0, so any child beats "no heavy child yet";
                        // strict '>' keeps the first child on ties, as before.
                        if (sz[v] > sz[son[x]]) son[x] = v;
                        sz[x] += sz[v];
                }
        }
}

// Second HLD pass: numbers vertices in DFS preorder (idx[v] = position,
// ord[position] = v, positions taken from the global counter id) visiting the
// heavy child first, so every heavy chain occupies consecutive positions.
// top[v] is the head of v's chain; u starts a chain headed by tu, fa is u's parent.
// Iterative (explicit stack) to avoid O(depth) recursion; the visiting order
// is exactly the recursive one: heavy subtree first, then light children in
// adjacency order.
void dfs2(int u, int tu, int fa) {
        std::vector<std::pair<int, int> > stk;  // (vertex, parent)
        stk.push_back(std::make_pair(u, fa));
        top[u] = tu;
        while (!stk.empty()) {
                int x = stk.back().first, px = stk.back().second;
                stk.pop_back();
                ord[id] = x;
                idx[x] = id++;
                // Push light children, reversed so the first one in adjacency
                // order is popped first; each starts its own chain.
                int base = (int)stk.size();
                ADJ(i, x) {
                        int v = edge[i].to;
                        if (v == px || v == son[x]) continue;
                        top[v] = v;
                        stk.push_back(std::make_pair(v, x));
                }
                std::reverse(stk.begin() + base, stk.end());
                // Heavy child goes on top: popped next, continuing x's chain.
                if (son[x]) {
                        top[son[x]] = top[x];
                        stk.push_back(std::make_pair(son[x], x));
                }
        }
}

// Argument shorthands for recursing into the left / right child. They expand
// to (child index, left bound, right bound) and rely on locals named rt, L, R
// and mid being in scope at the call site.
#define lch rt << 1, L, mid
#define rch rt << 1 | 1, mid + 1, R
// Segment tree over positions 1..n (HLD order); per node:
//   sumv - sum of the node's range
//   lmax - best non-empty prefix sum (starting at the range's left end)
//   rmax - best non-empty suffix sum (ending at the range's right end)
//   mmax - best non-empty subsegment sum
//   lazy - pending "assign this value to every element" for the children;
//          INF means nothing is pending
int lmax[N << 2], rmax[N << 2], mmax[N << 2], sumv[N << 2], lazy[N << 2];

// Recomputes node rt's summary from its two children (left = rt*2, right = rt*2+1).
void push_up(int rt) {
        const int lc = rt << 1, rc = rt << 1 | 1;
        sumv[rt] = sumv[lc] + sumv[rc];
        // Best prefix: stays inside the left child, or covers it fully and extends right.
        lmax[rt] = max(lmax[lc], sumv[lc] + lmax[rc]);
        // Best suffix: the mirror image.
        rmax[rt] = max(rmax[rc], sumv[rc] + rmax[lc]);
        // Best segment: inside one child, or straddling the boundary between them.
        mmax[rt] = max(rmax[lc] + lmax[rc], max(mmax[lc], mmax[rc]));
}

// Propagates a pending assignment on node rt (range [L, R]) to both children:
// each child becomes "every element equals val" and inherits the tag.
void push_down(int rt, int L, int R) {
        if (lazy[rt] == INF)
                return;  // nothing pending
        const int val = lazy[rt];
        lazy[rt] = INF;
        const int mid = (L + R) / 2;
        for (int k = 0; k < 2; ++k) {
                const int child = rt << 1 | k;
                const int len = (k == 0) ? mid - L + 1 : R - mid;
                // With all elements equal, the best non-empty segment is the whole
                // range when val is positive, otherwise a single element.
                const int best = (val > 0) ? val * len : val;
                lmax[child] = rmax[child] = mmax[child] = best;
                sumv[child] = val * len;
                lazy[child] = val;
        }
}

// Builds node rt over positions [L, R]; leaf at position p holds the weight of
// vertex ord[p]. Clears any lazy tag on the way.
void build(int rt, int L, int R) {
        lazy[rt] = INF;  // no pending assignment
        if (L != R) {
                int mid = (L + R) / 2;
                build(lch);
                build(rch);
                push_up(rt);
                return;
        }
        const int w = W[ord[L]];
        sumv[rt] = lmax[rt] = rmax[rt] = mmax[rt] = w;
}

// Assigns val to every position in [l, r] within node rt's range [L, R].
void modify(int rt, int L, int R, int l, int r, int val) {
        if (l <= L && R <= r) {
                // Fully covered: overwrite the summary and tag the node.
                const int len = R - L + 1;
                sumv[rt] = val * len;
                // All elements equal: best segment is the whole range unless val < 0,
                // in which case it is a single element.
                const int best = (val < 0) ? val : sumv[rt];
                lmax[rt] = rmax[rt] = mmax[rt] = best;
                lazy[rt] = val;
                return;
        }
        push_down(rt, L, R);
        int mid = (L + R) / 2;
        if (l <= mid) modify(lch, l, r, val);
        if (mid < r) modify(rch, l, r, val);
        push_up(rt);
}

// Merges the summaries of the segment-tree nodes covering positions [l, r]
// into accumulator `flag` (slot 0 or 1 of sl / sr / sum / sm).
// Children are visited right before left, so nodes arrive in decreasing
// position order and each one is merged onto the LEFT (lower-position) end of
// what has been accumulated so far. Accumulator fields:
//   sl  - best non-empty prefix sum starting at its left end
//   sr  - best non-empty suffix sum ending at its right end
//   sum - total sum;  sm - best non-empty subsegment sum
// sl[flag] == -INF marks an empty accumulator (gao() resets it before use).
void query(int rt, int L, int R, int l, int r, bool flag) {
        if (l <= L && R <= r) {
                if (sl[flag] == -INF) {
                        // First piece: the accumulator becomes a copy of this node.
                        sl[flag] = lmax[rt];
                        sr[flag] = rmax[rt];
                        sm[flag] = mmax[rt];
                        sum[flag] = sumv[rt];
                }
                else {
                        // Node rt lies immediately left of the accumulator. Statement
                        // order matters: each line reads values later lines overwrite.
                        // best: inside either part, or node suffix + accumulator prefix
                        sm[flag] = max(max(sm[flag], mmax[rt]), sl[flag] + rmax[rt]);
                        // prefix: node prefix alone, or whole node + accumulator prefix
                        sl[flag] = max(lmax[rt], sumv[rt] + sl[flag]);
                        // suffix: accumulator suffix alone, or whole accumulator + node suffix
                        sr[flag] = max(sr[flag], sum[flag] + rmax[rt]);
                        sum[flag] += sumv[rt];
                }
                return;
        }
        int mid = L + R >> 1;
        push_down(rt, L, R);
        // Right child first, so that every merged node is left of the accumulator.
        if (r > mid)
                query(rch, l, r, flag);
        if (l <= mid)
                query(lch, l, r, flag);
}

// Handles "2 u v val": assigns val to every vertex on the tree path u-v.
void update(int u, int v, int val) {
        // Climb chain by chain, always lifting the endpoint whose chain head is deeper.
        while (top[u] != top[v]) {
                if (dep[top[u]] < dep[top[v]])
                        swap(u, v);
                modify(1, 1, n, idx[top[u]], idx[u], val);
                u = pa[top[u]];
        }
        // Same chain: the shallower endpoint has the smaller position.
        if (dep[u] > dep[v])
                modify(1, 1, n, idx[v], idx[u], val);
        else
                modify(1, 1, n, idx[u], idx[v], val);
}

// Handles "1 u v": the maximum sum of a contiguous run of vertices on the tree
// path u-v, or 0 if every such sum is negative.
// The path is split at the LCA into two halves collected in separate
// accumulators (side 0 climbs from u, side 1 from v). Positions grow downward
// along a chain, so each half has its shallow end on the left, and the best run
// crossing the LCA is sl[0] + sl[1].
int gao(int u, int v) {
        for (int side = 0; side < 2; side++)
                sl[side] = sm[side] = -INF;  // -INF in sl marks an empty accumulator

        while (top[u] != top[v]) {
                // Lift the endpoint whose chain head is deeper (v on ties).
                bool fromU = dep[top[u]] > dep[top[v]];
                int &cur = fromU ? u : v;
                int head = top[cur];
                query(1, 1, n, idx[head], idx[cur], !fromU);
                cur = pa[head];
        }

        // One chain left: the shallower endpoint is the LCA, and the segment joins
        // the deeper endpoint's side (side 1 when the depths are equal).
        if (dep[u] > dep[v])
                query(1, 1, n, idx[v], idx[u], false);
        else
                query(1, 1, n, idx[u], idx[v], true);

        int best = max(sl[0] + sl[1], max(sm[0], sm[1]));
        return best > 0 ? best : 0;
}

// Processes test cases until input is exhausted. Each case: n, the n vertex
// weights, n-1 edges, m, then m operations:
//   "1 a b"   -> print the maximum subsegment sum on path a-b (0 if all negative)
//   "2 a b c" -> assign weight c to every vertex on path a-b
// Every scanf result is checked: the old `~scanf(...)` loop test only stopped at
// EOF, so a matching failure (return 0) made it loop forever on bad input.
// Truncated input now ends the program instead of running on garbage.
int main() {
        int a, b, c, d;
        while (scanf("%d", &n) == 1) {
                REP(i, 1, n) if (scanf("%d", W + i) != 1) return 0;
                init();
                FOR (i, n - 1) {
                        if (scanf("%d%d", &a, &b) != 2) return 0;
                        addEdge(a, b);
                }
                dfs1(1, -1);
                dfs2(1, 1, -1);

                build(1, 1, n);
                if (scanf("%d", &m) != 1) return 0;

                FOR (i, m) {
                        if (scanf("%d%d%d", &a, &b, &c) != 3) return 0;
                        if (a == 1)
                                printf("%d\n", gao(b, c));
                        else {
                                if (scanf("%d", &d) != 1) return 0;
                                update(b, c, d);
                        }
                }
        }
        return 0;
}

// 發表評論
// 所有評論
// 還沒有人評論，想成為第一個評論的人麼？請在上方評論欄輸入並且點擊發布。
// 相關文章