[CTSC2018]暴力写挂

题目链接

两棵[latex]n[/latex]个节点的树,求[latex]\max\limits_{1\le i\le n,1\le j\le n}\{depth(i)+depth(j)-depth(lca(i,j))-depth'(lca'(i,j))\}[/latex]

[latex]depth(i)+depth(j)-depth(lca(i,j))[/latex]其实可以看做[latex]dis(i,lca(i,j))+dis(j,lca(i,j))+dis(root,lca(i,j))[/latex]。

考虑边分治。

对于分治中心[latex]z[/latex]的子树中的点[latex]x[/latex]维护[latex]f(x)=dis(x,root)[/latex];
对于其他点维护[latex]g(y)[/latex],表示从[latex]y[/latex]到链[latex]root\rightarrow z[/latex]的距离。

显然[latex]f(x)+g(y)=dis(x,lca)+dis(y,lca)+dis(root,lca)[/latex]。

每一次分治对[latex]Tree2[/latex]建虚树,在虚树上用儿子的[latex]f(x),g(x)[/latex]更新父亲的,然后在每一个[latex]lca[/latex]处更新答案,就相当于在[latex]lca'[/latex]处讨论了所有最近公共祖先为[latex]lca'[/latex]的点对,单次复杂度就降低到了[latex]O(n)[/latex]。

时间复杂度[latex]O(n\log n)[/latex]。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

// Shorthand for the integer type used for edge weights, distances and the answer.
#define ll long long
// Capacity of every per-node array in this file; edge arrays are sized Max * 2.
const int Max = 800000;
// Magnitude of the "no value" sentinel: solve() stores -inf in f[][] for missing entries.
const ll inf = 9999999999999999ll;

// Reads the next decimal integer from stdin and returns it.
//
// Every character before the first digit is skipped; if any skipped
// character is '-', the result is negated. The first non-digit character
// after the number is consumed as well.
// Returns 0 when end of input is reached before any digit is found.
long long read() {
    long long x = 0, f = 1;
    // Keep the character as int: getchar() reports end of input as EOF, which
    // a plain char cannot represent reliably.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

struct edge {
    int to, nxt;
    ll w;
    edge() {}
    edge(int a, int b, ll c) {
        to = a, nxt = b, w = c;
    }
};

// Number of nodes in each input tree; main() reads it first and then reads n - 1 edges per tree.
int n;

// A weighted tree stored as linked adjacency lists, with the bookkeeping
// needed for DFS numbering and heavy-chain LCA queries.
//
// Edge index 0 is the list terminator; real edges start at index 1.
// Every instance in this file is a global, so all arrays start zeroed.
struct Tree {
    // N     - node counter (main() sets it for T1 and rebuild() increments it)
    // cntd  - DFS clock, cnte - number of directed edges stored so far
    // in/out- DFS entry time of a node / largest entry time inside its subtree
    // id    - inverse of in: id[in[u]] == u
    int N, cntd, cnte, in[Max], out[Max], id[Max];
    // dep   - depth in edges, fa - parent, top - head of the node's heavy chain,
    // fir   - head of the node's adjacency list, siz - subtree size
    int dep[Max], fa[Max], top[Max], fir[Max], siz[Max];
    // dis   - sum of edge weights from the DFS root down to the node
    ll dis[Max];
    edge e[Max * 2];

    // Inserts the undirected edge (u, v) with weight w as two directed halves.
    void addedge(int u, int v, ll w) {
        e[++cnte] = edge(v, fir[u], w), fir[u] = cnte;
        e[++cnte] = edge(u, fir[v], w), fir[v] = cnte;
    }

    // Unlinks the first directed edge u -> v from u's adjacency list.
    // Does nothing when either endpoint is 0 or no such edge exists.
    void deledge(int u, int v) {
        if (!u || !v) return;
        if (!fir[u]) return;  // empty list: nothing to unlink
        if (e[fir[u]].to == v) {
            fir[u] = e[fir[u]].nxt;
            // Stop here so that exactly one edge is removed, matching the
            // behaviour of the non-head case below.
            return;
        }
        for (int pre = fir[u], i = e[fir[u]].nxt; i; pre = i, i = e[i].nxt)
            if (e[i].to == v) {
                e[pre].nxt = e[i].nxt;
                break;
            }
    }

    // First DFS from u: fills siz, in, out, id and, for every child,
    // fa, dep and dis. fa[u], dep[u] and dis[u] must already be set for u.
    void dfs1(int u) {
        siz[u] = 1, in[u] = ++cntd, id[cntd] = u;
        for (int i = fir[u]; i; i = e[i].nxt)
            if (e[i].to != fa[u]) {
                fa[e[i].to] = u;
                dep[e[i].to] = dep[u] + 1;
                dis[e[i].to] = dis[u] + e[i].w;
                dfs1(e[i].to);
                siz[u] += siz[e[i].to];
            }
        out[u] = cntd;
    }

    // Second DFS: assigns chain heads. The child with the largest subtree
    // continues u's chain (head tp); every other child starts its own chain.
    void dfs2(int u, int tp) {
        int big = 0; top[u] = tp;
        for (int i = fir[u]; i; i = e[i].nxt)
            if (e[i].to != fa[u] && siz[e[i].to] > siz[big])
                big = e[i].to;
        if (big) dfs2(big, tp);
        for (int i = fir[u]; i; i = e[i].nxt)
            if (e[i].to != fa[u] && e[i].to != big)
                dfs2(e[i].to, e[i].to);
    }

    // Returns the lowest common ancestor of x and y.
    // Requires dfs1 and dfs2 to have been run.
    int getlca(int x, int y) {
        while (top[x] != top[y]) {
            // Always lift the node whose chain head is deeper.
            if (dep[top[x]] < dep[top[y]])
                swap(x, y);
            x = fa[top[x]];
        }
        return dep[x] < dep[y] ? x : y;
    }
} T0, T1, T2;

// Copies the subtree of T0 rooted at u (whose parent is Fa) into T1.
//
// The first child of u is attached to u directly. For every later child a
// fresh helper node is allocated from T1.N, linked to the previous attachment
// point with a weight-0 edge, and the child is attached to that helper with
// the original edge weight.
void rebuild(int u, int Fa) {
    int attach = u;          // T1 node the next child edge hangs from
    bool seenChild = false;  // true once the first child has been attached
    for (int k = T0.fir[u]; k != 0; k = T0.e[k].nxt) {
        const int child = T0.e[k].to;
        if (child != Fa) {
            if (seenChild) {
                const int helper = ++T1.N;
                T1.addedge(attach, helper, 0);
                attach = helper;
            }
            T1.addedge(attach, child, T0.e[k].w);
            rebuild(child, u);
            seenChild = true;
        }
    }
}

// Working state shared by update(), solve() and main().
// tot    - number of entries written to ta/tl while solve() partitions a[l..r]
// top    - height of the stack sta
// L, R   - head / tail of the BFS queue que
// lca[i] - main() sets it to the T2 LCA of a[i] and a[i + 1]; solve() rewrites it when it reorders a
// a      - node list; main() fills it with T2.id, i.e. nodes in T2 DFS order
// que    - BFS order of the T1 component currently handled by solve()
int tot, top, L, R, lca[Max], a[Max], que[Max];
// fa, siz - BFS parent and subtree size inside the current T1 component
// sta     - node stack used by the sweep over a[l..r] in solve()
// ta, tl  - scratch copies of a and lca built while partitioning
int fa[Max], siz[Max], sta[Max], ta[Max], tl[Max];
// ans     - best value found so far (printed by main())
// f[x][0], f[x][1] - the two per-node values combined by update(); -inf means "none"
ll ans, f[Max][2];

void update(int x, int y) {
    ans = max(ans, max(f[x][0] + f[y][1], f[x][1] + f[y][0]) - T2.dis[x]);
    f[x][0] = max(f[x][0], f[y][0]);
    f[x][1] = max(f[x][1], f[y][1]);
}

// Edge divide-and-conquer step on T1.
//
// rt     - node the BFS of the current T1 component starts from
// l, r   - a[l..r] is the slice of the node list handled by this call;
//          for l <= i < r, lca[i] is read as the T2 LCA of a[i] and a[i + 1]
//
// The call picks a split node mid, assigns f[][] to every node of the
// component, sweeps a[l..r] with a stack to combine the values through
// update(), then partitions a[l..r] by whether the node lies in mid's T1
// subtree, cuts the edge between mid and its BFS parent, and recurses on
// both parts.
void solve(int rt, int l, int r) {
    if (l > r) return;
    if (l == r) {
        // One node left: evaluate it on its own.
        ans = max(ans, T1.dis[a[l]] - T2.dis[a[l]]);
        return;
    }
    int mid = rt;

    // BFS over the component to obtain an order (que) and parents (fa).
    L = R = 1;
    que[R++] = rt, fa[rt] = 0;
    while (L < R) {
        int u = que[L++];
        siz[u] = 1;
        for (int i = T1.fir[u]; i; i = T1.e[i].nxt) {
            int v = T1.e[i].to;
            if (v == fa[u]) continue;
            fa[v] = u, que[R++] = v;
        }
    }
    R--;  // R is now the number of nodes in the component

    // Accumulate subtree sizes bottom-up (reverse BFS order).
    for (int i = R; i; i--)
        if (fa[que[i]])
            siz[fa[que[i]]] += siz[que[i]];

    // mid is the node whose parent edge splits the component most evenly.
    int mn = 123456789;
    for (int i = R; i; i--) {
        int big = max(siz[rt] - siz[que[i]], siz[que[i]]);
        if (big < mn) {
            mn = big, mid = que[i];
        }
    }

    // Nodes that appear only as an LCA start with no value on either side.
    for (int i = l; i < r; i++)
        f[lca[i]][0] = f[lca[i]][1] = -inf;

    // Assign f in BFS order so a parent's value is ready before its children.
    for (int i = 1; i <= R; i++) {
        int x = que[i];
        int in = T1.in[x], out = T1.out[x];
        if (in < T1.in[mid] || in > T1.out[mid]) {
            // x lies outside mid's T1 subtree.
            if (in <= T1.in[mid] && T1.in[mid] <= out) {
                // mid is inside x's subtree: x contributes 0.
                f[x][1] = -inf, f[x][0] = 0;
            }
            else {
                // Otherwise extend the parent's value by the T1 edge length.
                f[x][1] = -inf, f[x][0] = f[fa[x]][0] + T1.dis[x] - T1.dis[fa[x]];
            }
        }
        else {
            // x lies inside mid's T1 subtree.
            f[x][1] = T1.dis[x], f[x][0] = -inf;
        }
    }

    // Stack sweep over a[l..r]: each update(parent, child) call merges the
    // child into the parent and tries to improve ans at the parent.
    sta[top = 1] = a[l];
    for (int i = l + 1; i <= r; i++) {
        int x = a[i], y = lca[i - 1];
        while (top > 1 && T2.in[sta[top - 1]] >= T2.in[y]) {
            update(sta[top - 1], sta[top]);
            top--;
        }
        if (sta[top] == y)
            sta[++top] = x;
        else {
            update(y, sta[top]);
            sta[top] = y, sta[++top] = x;
        }
    }
    for (; top > 1; top--)
        update(sta[top - 1], sta[top]);

    // Partition a[l..r], keeping the relative order inside each part.
    // tl[k] starts as ta[k] and is replaced by any later lca[i] with a smaller
    // T2 entry time, until the next node of the same part is appended.
    tot = 0;
    for (int i = l; i <= r; i++) {
        int dfn = T1.in[a[i]];
        if (dfn >= T1.in[mid] && dfn <= T1.out[mid]) {
            // Advance tot in its own statement: writing ta[++tot] = tl[tot]
            // in one expression leaves the index used for tl dependent on
            // evaluation order.
            ++tot;
            ta[tot] = tl[tot] = a[i];
        }
        if (tot && T2.in[tl[tot]] > T2.in[lca[i]])
            tl[tot] = lca[i];
    }
    int tmp = tot;  // size of the part inside mid's subtree
    for (int i = l; i <= r; i++) {
        int dfn = T1.in[a[i]];
        if (dfn < T1.in[mid] || dfn > T1.out[mid]) {
            ++tot;
            ta[tot] = tl[tot] = a[i];
        }
        if (tot && T2.in[tl[tot]] > T2.in[lca[i]])
            tl[tot] = lca[i];
    }
    for (int i = 0; i <= r - l; i++) {
        lca[l + i] = tl[i + 1];
        a[l + i] = ta[i + 1];
    }

    // Cut the edge between mid and its parent in both directions, then recurse.
    T1.deledge(mid, fa[mid]);
    T1.deledge(fa[mid], mid);
    solve(mid, l, l + tmp - 1);
    solve(rt, l + tmp, r);
}

// Reads both trees, prepares the auxiliary structures and prints the answer.
int main() {
    n = read();

    // First tree goes into T0.
    for (int k = 1; k < n; k++) {
        int u = read();
        int v = read();
        ll w = read();
        T0.addedge(u, v, w);
    }
    // Second tree goes into T2.
    for (int k = 1; k < n; k++) {
        int u = read();
        int v = read();
        ll w = read();
        T2.addedge(u, v, w);
    }

    // Build T1 from T0; helper nodes are numbered after the n original ones.
    T1.N = n;
    rebuild(1, 0);

    T1.dfs1(1);
    T2.dfs1(1);
    T2.dfs2(1, 1);

    // a lists the nodes in T2 DFS order; lca[k] is the T2 LCA of neighbours
    // a[k] and a[k + 1].
    for (int k = 1; k <= n; k++) {
        a[k] = T2.id[k];
        if (k < n)
            lca[k] = T2.getlca(T2.id[k], T2.id[k + 1]);
    }

    solve(1, 1, n);
    printf("%lld\n", ans);
    return 0;
}

1人评论了“[CTSC2018]暴力写挂”

SYCstudio进行回复 取消回复

电子邮件地址不会被公开。 必填项已用*标注