[SDOI2017]切树游戏

题目链接

immortalCO UOJ

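上面链接里写得很详细。设 $f_u$ 为以 $u$ 为最高点的连通块按异或值计数的生成函数，则 dp 方程就是把 $x^{w_u}$ 与每个儿子的 $(f_v + x^0)$（$x^0$ 对应不选该儿子）依次做异或卷积：$f_u = x^{w_u} \otimes \prod_{v \in \mathrm{son}(u)} (f_v + x^0)$。FWT 之后，异或卷积可以转化为每一位上的点乘。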

写出乘法的转移矩阵，做动态 DP 即可。注意模数很小（$10007$），FWT 后某一位的乘积很可能恰好 $\equiv 0 \pmod{10007}$，此时无法用逆元把某个因子除掉。因此需要手写一个在模意义下支持乘除 $0$ 的整数类型：分别记录非零因子的乘积与因子 $0$ 的个数。

#include <bits/stdc++.h>
using namespace std;
// Reads the next integer from stdin, skipping any non-digit characters before it.
// A '-' seen while skipping makes the result negative.
// Returns 0 if end of input is reached before any digit.
// `ch` is an int (not char) so EOF is distinguishable from a data byte and
// isdigit() is never handed a negative char value (undefined behaviour).
int read() {
    int x = 0, f = 1;
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}

// Capacity for per-vertex arrays. Vertices are numbered 1..n and used as
// indices, so n must stay below Max.
const int Max = 30033;
// Node-pool capacity for the per-chain segment trees. A chain of length L gets
// exactly 2L - 1 nodes (see build), so all chains together use at most 2n - 1
// nodes, and node 0 is never allocated. 2 * Max is therefore sufficient.
// The former 4 * Max only added ~124 MB of untouched static storage across
// h/lv/rv/s. The name is kept for the declarations that use it.
const int Max4 = Max * 2;
// Answer modulus. It is prime, so every non-zero residue is invertible.
const int mod = 10007;
// Capacity for the value-range dimension (presumably m <= 128 — TODO confirm
// against the statement).
const int MaxM = 129;
// Inverse of 2 modulo 10007 (2 * 5004 = 10008 ≡ 1), used by the inverse transform.
const int inv2 = 5004;
// n vertices; m = length of the transformed vectors; q operations;
// cnte = arcs stored so far; fir/w = adjacency-list heads and vertex weights.
int n, m, q, cnte, fir[Max], w[Max];
// e[v]: transformed indicator vector of value v.
// inv[i]: inverse of i modulo mod. It is indexed only by residues 1..mod-1,
// so it needs exactly mod entries.
// ans: running total in the transformed domain; tans: scratch copy for queries.
int e[MaxM][MaxM], inv[mod], ans[MaxM], tans[MaxM];
// Operation-name buffer.
char opt[10];
// In-place modular addition: a <- (a + b) mod `mod`.
void Add(int &a, int b) {
    a += b;
    a %= mod;
}
// In-place modular subtraction: a <- (a - b) mod `mod`.
// Adding `mod` first keeps the intermediate non-negative for a, b in [0, mod).
void Dec(int &a, int b) {
    a += mod - b;
    a %= mod;
}
// Binary exponentiation: returns x^k reduced modulo `mod`.
int qpow(int x, int k) {
    int result = 1;
    int base = x;
    while (k != 0) {
        if (k & 1)
            result = result * base % mod;
        base = base * base % mod;
        k >>= 1;
    }
    return result;
}
// Adjacency-list arc: `to` is the head vertex and `nxt` is the index of the
// next arc from the same tail (index 0 ends a list). `edg` is the arc pool;
// it holds two arcs per undirected edge.
struct edge {
    int to, nxt;
    edge() {}
    edge(int a, int b) : to(a), nxt(b) {}
} edg[Max * 2];
// Stores the undirected edge {u, v} as two directed arcs.
// Each arc is prepended to its tail's list; arc indices start at 1 so 0 can end a list.
void addedge(int u, int v) {
    auto link = [](int from, int to) {
        edg[++cnte] = edge(to, fir[from]);
        fir[from] = cnte;
    };
    link(u, v);
    link(v, u);
}
// In-place Walsh–Hadamard (XOR) transform of a[0..n) modulo `mod`.
// Assumes n is a power of two and entries are in [0, mod).
void fwt(int *a, int n) {
    for (int half = 1; half < n; half *= 2) {
        for (int base = 0; base < n; base += 2 * half) {
            int *lo = a + base;
            int *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const int u = lo[k], v = hi[k];
                lo[k] = (u + v) % mod;
                hi[k] = (u - v + mod) % mod;
            }
        }
    }
}
// Inverse of fwt: the same butterflies, with every output multiplied by inv2.
// Assumes n is a power of two and entries are in [0, mod).
void ifwt(int *a, int n) {
    for (int half = 1; half < n; half *= 2) {
        for (int base = 0; base < n; base += 2 * half) {
            int *lo = a + base;
            int *hi = lo + half;
            for (int k = 0; k < half; ++k) {
                const int u = lo[k], v = hi[k];
                lo[k] = (u + v) * inv2 % mod;
                hi[k] = (u - v + mod) * inv2 % mod;
            }
        }
    }
}
// Residue modulo `mod` that supports multiplying and dividing by 0.
// A value is represented as val * 0^c0: val is the product of the non-zero
// factors (1 when there are none, 0 for a default-constructed object), and
// c0 counts the zero factors currently multiplied in.
// f[u][i] holds vertex u's light-side dp product at transformed coordinate i.
struct int0 {
    int val, c0;
    int0() : val(0), c0(0) {}
    int0(int x) : val(x ? x : 1), c0(x ? 0 : 1) {}
    // Multiply by the residue b; a zero factor only bumps the zero count.
    friend int0 operator* (int0 a, int b) {
        if (b == 0) {
            ++a.c0;
            return a;
        }
        a.val = a.val * b % mod;
        return a;
    }
    // Divide by the residue b; dividing by 0 cancels one earlier zero factor.
    friend int0 operator/ (int0 a, int b) {
        if (b == 0) {
            --a.c0;
            return a;
        }
        a.val = a.val * inv[b] % mod;
        return a;
    }
    // Plain residue: 0 while any zero factor remains, otherwise val.
    int getv() { return c0 ? 0 : val; }
} f[Max][MaxM];
void init() {
    for (int i = 0; i < m; i++) {
        e[i][i] = 1;
        fwt(e[i], m);
    }
    for (int i = 1; i < mod; i++)
        inv[i] = qpow(i, mod - 2);
    for (int i = 1; i <= n; i++)
        for (int j = 0; j < m; j++)
            f[i][j] = int0(e[w[i]][j]);
}

// p[tp]: vertices of the heavy chain whose top is tp, listed top to bottom.
vector<int> p[Max];
// t[1..k]: heads (top vertices) of all heavy chains; k = number of chains.
int t[Max], k;
// Heavy-light decomposition data: subtree size, depth, parent, chain top,
// and heavy child (0 when the vertex has no children).
int siz[Max], dep[Max], fa[Max], top[Max], big[Max];
// Sort predicate ordering vertices from deepest to shallowest.
bool cmp(int a, int b) { return dep[b] < dep[a]; }
// First heavy-light pass over u's subtree. It sets fa/dep of u's children,
// computes siz[u], and picks the heavy child big[u].
// siz[0] stays 0, so big[u] == 0 loses to any real child.
void dfs1(int u) {
    siz[u] = 1;
    for (int i = fir[u]; i != 0; i = edg[i].nxt) {
        const int v = edg[i].to;
        if (v == fa[u])
            continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
        if (siz[big[u]] < siz[v])
            big[u] = v;
    }
}
// Second heavy-light pass. It labels u with chain top tp, records chain heads
// in t[1..k], and appends u to p[tp]. The heavy child is visited first, so each
// chain list is ordered from top to bottom; light children start new chains.
void dfs2(int u, int tp) {
    top[u] = tp;
    if (u == tp)
        t[++k] = u;
    p[tp].emplace_back(u);
    if (big[u] != 0)
        dfs2(big[u], tp);
    for (int i = fir[u]; i != 0; i = edg[i].nxt) {
        const int v = edg[i].to;
        if (v != fa[u] && v != big[u])
            dfs2(v, v);
    }
}

// Per segment-tree node and per transformed coordinate, over the node's range
// of leaf values (see pushup):
//   h      = sum of products over all contiguous sub-ranges,
//   lv, rv = sums of prefix products and suffix products,
//   s      = product of the whole range.
int h[Max4][MaxM], lv[Max4][MaxM], rv[Max4][MaxM], s[Max4][MaxM];
// tot = nodes allocated so far; rt[tp] = root of chain tp's tree;
// ls/rs = children; tf = parent node (stays 0 at a root);
// pos[u] = leaf node holding vertex u.
int tot, rt[Max], ls[Max4], rs[Max4], tf[Max4], pos[Max];
void pushup(int x) {
    for (int i = 0; i < m; i++) {
        h[x][i] = (h[ls[x]][i] + h[rs[x]][i] + rv[ls[x]][i] * lv[rs[x]][i]) % mod;
        lv[x][i] = (lv[ls[x]][i] + s[ls[x]][i] * lv[rs[x]][i]) % mod;
        rv[x][i] = (rv[rs[x]][i] + s[rs[x]][i] * rv[ls[x]][i]) % mod;
        s[x][i] = s[ls[x]][i] * s[rs[x]][i] % mod;
    }
}
int build(int l, int r, int t) {
    int x = ++tot;
    if (l == r) {
        for (int i = 0; i < m; i++)
            h[x][i] = lv[x][i] = rv[x][i] = s[x][i] = f[p[t][l - 1]][i].getv();
        pos[p[t][l - 1]] = x;
        return x;
    }
    int mid = (l + r) >> 1;
    ls[x] = build(l, mid, t), rs[x] = build(mid + 1, r, t);
    tf[ls[x]] = tf[rs[x]] = x;
    pushup(x);
    return x;
}
// Reloads vertex id's row f[id] into its leaf and refreshes what depends on
// id's chain (tp = top of that chain):
//   1. divide the chain's old factor (lv[root] + e[0]) out of f[fa[tp]],
//   2. subtract the chain's old interval sum h[root] from ans,
//   3. rewrite the leaf and re-merge its ancestors inside the chain's tree,
//   4. multiply the new factor back into f[fa[tp]] and add the new h[root].
// The caller then updates fa[tp] to propagate the change to ancestor chains.
void update(int id) {
    int x = pos[id], tp = top[id];
    // e[0][i] is 1 for every i, so this factor is (chain-top dp value + 1).
    // The division goes through int0, so a factor that is 0 mod 10007 is
    // cancelled correctly.
    if (fa[tp]) {
        for (int i = 0; i < m; i++)
            f[fa[tp]][i] = f[fa[tp]][i] / ((lv[rt[tp]][i] + e[0][i]) % mod);
    }
    for (int i = 0; i < m; i++)
        Dec(ans[i], h[rt[tp]][i]);
    // A leaf covers one vertex, so all four summaries equal its value.
    for (int i = 0; i < m; i++)
        h[x][i] = lv[x][i] = rv[x][i] = s[x][i] = f[id][i].getv();
    // Walk leaf -> root of this chain's tree (tf is 0 above the root).
    for (x = tf[x]; x; x = tf[x])
        pushup(x);
    if (fa[tp]) {
        for (int i = 0; i < m; i++)
            f[fa[tp]][i] = f[fa[tp]][i] * ((lv[rt[tp]][i] + e[0][i]) % mod);
    }
    for (int i = 0; i < m; i++)
        Add(ans[i], h[rt[tp]][i]);
}

// Reads n, m, the n vertex weights and the n - 1 edges, then builds the
// heavy-light decomposition and one segment tree per chain. Afterwards it
// processes q operations from stdin.
int main() {
    n = read(), m = read();
    for (int i = 1; i <= n; i++)
        w[i] = read();
    for (int i = 1; i < n; i++) {
        int u = read(), v = read();
        addedge(u, v);
    }
    init();
    dfs1(1), dfs2(1, 1);
    // Process chain tops deepest first. A light child's chain top is deeper
    // than its parent's chain top, so each child chain is built (and its factor
    // folded into f[fa[x]]) before build() reads the parent chain's f rows.
    sort(t + 1, t + k + 1, cmp);
    for (int i = 1; i <= k; i++) {
        int x = t[i];
        rt[x] = build(1, p[x].size(), x);
        // Each chain adds its sum over all contiguous sub-chains to the total.
        for (int j = 0; j < m; j++)
            Add(ans[j], h[rt[x]][j]);
        // Light-edge transition: multiply the parent's row by (lv[root] + 1);
        // e[0][j] is 1 for every j.
        if (fa[x]) {
            for (int j = 0; j < m; j++)
                f[fa[x]][j] = f[fa[x]][j] * ((lv[rt[x]][j] + e[0][j]) % mod);
        }
    }
    q = read();
    while (q--) {
        scanf("%s", opt);
        // Operation starting with 'Q' reads v: transform a copy of ans back
        // and print coefficient v.
        if (opt[0] == 'Q') {
        	int v = read();
            for (int i = 0; i < m; i++)
                tans[i] = ans[i];
            ifwt(tans, m);
            printf("%d\n", tans[v]);
        }
        // Any other operation reads x, v and sets w[x] = v. It swaps the
        // weight's transformed indicator inside f[x], then propagates the change
        // chain by chain toward the root via fa[top[x]].
        else {
            int x = read(), v = read();
            for (int i = 0; i < m; i++)
                f[x][i] = f[x][i] / e[w[x]][i];
            for (int i = 0; i < m; i++)
                f[x][i] = f[x][i] * e[v][i];
            w[x] = v;
            for (; x; x = fa[top[x]]) update(x);
        }
    }
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注