けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 302 Ex - Ball Collector (橙色, 625 点)

undo 付き Union-Find!

問題概要

頂点数  N の木が与えられる。各頂点には、数値  A_{i} の書かれたボールと、数値  B_{i} の書かれたボールがある。

 k = 2, 3, \dots, N に対して、次の問に答えよ。

  • パス  1- k 上の各頂点から
  • ボールを 1 個ずつ選んだときの
  • ボールに書かれた数値の種類数の最大値を答えよ。

制約

  •  2 \le N \le 2 \times 10^{5}
  •  1 \le A_{i}, B_{i} \le N

考えたこと

木に対する問題を解くときには、まずパスについて解くと良い場合が多いと思う。

パスについての問題とは、すなわち「各  i に対して  A_{i} B_{i} かを選んだときの、数値の種類数の最大値」を求める問題だ。

そして、実はそれはすでに出題されている!

drken1215.hatenablog.com

Union-Find を使う。各  i に対して、数値  A_{i} に対応する頂点と数値  B_{i} に対応する頂点をマージしていく。そうして、各グループ (サイズ  s とする)

  • 辺数 (merge が呼ばれた回数) が  s-1 ならば、 s-1
  • 辺数が  s 以上ならば、 s

を足したものが答えとなる。

木だと

木において各頂点について考える場合も、同様のことが成り立つ。ただし、木を DFS するとき、「探索済みの頂点から親頂点へと戻る」というバックトラックの動きをすることになる。

よって、Union-Find において「merge(A[i], B[i]) をなかったことにする操作」ができるようにする必要がある。

そのためには、undo 付きの Union-Find が使える。全体として、計算量は  O(N \log N) となる (経路圧縮をしないことに注意)。

コード

#include <bits/stdc++.h>
using namespace std;


// Union-Find, we can undo
struct UnionFind {
    // core member
    vector<int> par;
    stack<pair<int,int>> history;
    
    // constructor
    UnionFind() {}
    UnionFind(int n) : par(n, -1) { }
    void init(int n) { par.assign(n, -1); }
    
    // core methods
    int root(int x) {
        if (par[x] < 0) return x;
        else return root(par[x]);
    }
    
    bool same(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x), y = root(y);
        history.emplace(x, par[x]);
        history.emplace(y, par[y]);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y);  // merge technique
        par[x] += par[y];
        par[y] = x;
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }

    // 1-step undo
    void undo() {
        for (int iter = 0; iter < 2; ++iter) {
            par[history.top().first] = history.top().second;
            history.pop();
        }
    }

    // erase history
    void snapshot() {
        while (!history.empty()) history.pop();
    }

    // all rollback
    void rollback() {
        while (!history.empty()) undo();
    }
    
    // debug
    friend ostream& operator << (ostream &s, UnionFind uf) {
        map<int, vector<int>> groups;
        for (int i = 0; i < uf.par.size(); ++i) {
            int r = uf.root(i);
            groups[r].push_back(i);
        }
        for (const auto &it : groups) {
            s << "group: ";
            for (auto v : it.second) s << v << " ";
            s << endl;
        }
        return s;
    }
};

void ABC_302_Ex() {
    using pint = pair<int,int>;
    using Graph = vector<vector<int>>;
    
    int N;
    cin >> N;
    vector<int> A(N), B(N);
    for (int i = 0; i < N; ++i) {
        cin >> A[i] >> B[i];
        --A[i], --B[i];
    }
    Graph G(N);
    for (int i = 0; i < N-1; ++i) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    
    // Union-Find
    // 外部データの undo もここで実現する
    UnionFind uf(N);
    vector<int> nums(N, 0);  // 各連結成分ごとの種類数
    int cur = 0;  // 現時点での種類数の最大値
    vector<pair<pint,int>> hist;  // 履歴
    vector<int> res(N, 0);  // 答え
    auto insert = [&](int v, int p) -> void {
        int x = uf.root(A[v]), y = uf.root(B[v]);
        hist.push_back(make_pair(pint(x, nums[x]), cur));
        hist.push_back(make_pair(pint(y, nums[y]), cur));
        int before = (uf.same(x, y) ? nums[x] : nums[x] + nums[y]);
        int after = (uf.same(x, y) ?
                     min(uf.size(x), nums[x] + 1) :
                     min(uf.size(x) + uf.size(y), nums[x] + nums[y] + 1));
        uf.merge(x, y);
        nums[uf.root(x)] = after;
        res[v] = (p != -1 ? res[p] : 0) + (after - before);
    };
    auto erase = [&]() -> void {
        for (int iter = 0; iter < 2; ++iter) {
            nums[hist.back().first.first] = hist.back().first.second;
            cur = hist.back().second;
            hist.pop_back();
        }
        uf.undo();
    };
    
    // DFS
    auto dfs = [&](auto self, int v, int p) -> void {
        insert(v, p);
        for (auto v2 : G[v]) {
            if (v2 == p) continue;
            self(self, v2, v);
        }
        erase();
    };
    dfs(dfs, 0, -1);
    
    for (int v = 1; v < N; ++v) cout << res[v] << " ";
    cout << endl;
}

int main() {
    ABC_302_Ex();
}