けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 002 D - Stamp Rally (1000 点)

部分永続 Union-Find 木の練習をした。

問題概要 (AGC 002 D)

N 頂点 M 辺の無向グラフがあり、辺には 1 から M までの番号が付いています。グラフは連結です。以下の Q 個のクエリに答えてください:

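  • 頂点 x, y と整数 z が与えられる。頂点 x, y からそれぞれ出発した 2 人が、合わせて z 個の頂点を訪れる (どちらか一方でも訪れた頂点は 1 個と数える) ようにするとき、通る辺の番号の最大値としてありうる最小値を求めてください。言い換えると、番号 T 以下の辺だけを使ったときに「x を含む連結成分と y を含む連結成分の頂点数の合計」が z 以上になるような最小の T を求める問題です。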

解法 1: 部分永続 Union-Find

#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
using namespace std;


// A (time, value) record stored in a root's merge history.
using pint = pair<int,int>;

/*
 * Partially persistent union-find: union by size, no path compression.
 *
 * Every merge is stamped with a caller-supplied time t. Each query also takes
 * a time t and is answered as if only the merges stamped with a time <= t
 * had been performed.
 * NOTE(review): size() binary-searches the history, so merges are expected to
 * be issued with non-decreasing times; nothing here enforces that.
 *
 *   par[v]     : for a current root, -(component size); otherwise v's parent.
 *   last[v]    : the time at which v was attached under another root, or -1.
 *   history[v] : (time, par[v]) snapshots appended whenever root v absorbs
 *                another root, after an initial (-1, -1) singleton sentinel.
 */
struct PartiallyPersistentUnionFind {
    vector<int> par, last;
    vector<vector<pint> > history;

    PartiallyPersistentUnionFind(int n) { init(n); }

    // Reset to n singleton components, each with only the sentinel snapshot.
    void init(int n) {
        par.assign(n, -1);
        last.assign(n, -1);
        history.assign(n, vector<pint>(1, pint(-1, -1)));
    }

    // Representative of x at time t: keep climbing while the current node
    // had already been hung under a parent by time t.
    int root(int t, int x) {
        while (last[x] != -1 && last[x] <= t) x = par[x];
        return x;
    }

    // Whether x and y were in the same component at time t.
    bool issame(int t, int x, int y) {
        const int rx = root(t, x);
        return rx == root(t, y);
    }

    // Join the components of x and y at time t, attaching the smaller root
    // under the larger. Returns false if they were already joined at time t.
    bool merge(int t, int x, int y) {
        int big = root(t, x), small = root(t, y);
        if (big == small) return false;
        if (par[big] > par[small]) swap(big, small); // merge technique
        par[big] += par[small];
        par[small] = big;
        last[small] = t;
        history[big].emplace_back(t, par[big]);
        return true;
    }

    // Size of x's component at time t: the newest snapshot of its root whose
    // timestamp is <= t (the sentinel guarantees one exists for t >= -1).
    int size(int t, int x) {
        const vector<pint> &snaps = history[root(t, x)];
        auto after = upper_bound(snaps.begin(), snaps.end(), t,
                                 [](int time, const pint &rec) { return time < rec.first; });
        return -prev(after)->second;
    }
};



/*
 * Solution 1: per-query binary search over time on a partially persistent
 * union-find built from every edge.
 *
 * Input edge i (0-indexed) is merged at time i + 1, so "time T" means that
 * only edges 1..T are usable. For a query (x, y, z) the answer is the
 * smallest T at which the components of x and y hold >= z vertices in total.
 */
int main() {
    // All I/O goes through iostreams, so detaching them from C stdio is safe.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, M;
    cin >> N >> M;
    PartiallyPersistentUnionFind uf(N);
    for (int t = 0; t < M; ++t) {
        int a, b;
        cin >> a >> b;
        --a, --b;
        uf.merge(t + 1, a, b);
    }

    // Vertices in the union of the components of x and y using edges 1..t.
    auto reachable = [&](int t, int x, int y) -> int {
        if (uf.issame(t, x, y)) return uf.size(t, x);
        return uf.size(t, x) + uf.size(t, y);
    };

    int Q;
    cin >> Q;
    for (int q = 0; q < Q; ++q) {
        int x, y, z;
        cin >> x >> y >> z;
        --x, --y;
        // Invariant: time `low` is insufficient and time `high` suffices.
        // low = 0 (no edges) is assumed insufficient, as in the original code.
        // The graph is connected, so after all M edges every vertex is
        // reachable; assuming z <= N (problem constraint), high = M suffices.
        int low = 0, high = M;
        while (high - low > 1) {
            const int mid = low + (high - low) / 2;
            if (reachable(mid, x, y) >= z) high = mid;
            else low = mid;
        }
        // '\n' instead of endl: avoid flushing once per query.
        cout << high << '\n';
    }
}

解法 2: 並列二分探索

#include <iostream>
#include <vector>
using namespace std;

/*
 * Union-find with union by size and path compression.
 * par[v] < 0 marks v as a root and stores -(component size);
 * otherwise par[v] is v's parent.
 */
struct UnionFind {
    vector<int> par;

    UnionFind(int n) { init(n); }

    // Reset to n singleton components.
    void init(int n) { par.assign(n, -1); }

    // Representative of x; afterwards every node on the walked path points
    // directly at it.
    int root(int x) {
        int rep = x;
        while (par[rep] >= 0) rep = par[rep];
        while (par[x] >= 0) {
            const int up = par[x];
            par[x] = rep;
            x = up;
        }
        return rep;
    }

    // Whether x and y currently share a component.
    bool issame(int x, int y) {
        const int rx = root(x);
        return rx == root(y);
    }

    // Join the components of x and y, attaching the smaller root under the
    // larger. Returns false if they were already joined.
    bool merge(int x, int y) {
        int big = root(x), small = root(y);
        if (big == small) return false;
        if (par[big] > par[small]) swap(big, small); // merge technique
        par[big] += par[small];
        par[small] = big;
        return true;
    }

    // Number of vertices in x's component.
    int size(int x) {
        const int rep = root(x);
        return -par[rep];
    }
};

/*
 * Solution 2: parallel binary search.
 *
 * For each query q, (le[q], ri[q]] brackets the answer:
 *   - using only the first le[q] edges, the components of X[q] and Y[q]
 *     hold fewer than Z[q] vertices;
 *   - using the first ri[q] edges, they hold at least Z[q].
 * Each round replays all edges once on a fresh union-find and halves every
 * still-open bracket. The printed answer ri[q] is a 1-indexed edge number.
 */
int main() {
    // Read everything through iostreams (the original mixed in scanf without
    // including <cstdio>); detaching from C stdio is then safe.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, M;
    cin >> N >> M;
    vector<int> A(M), B(M);
    for (int i = 0; i < M; ++i) {
        cin >> A[i] >> B[i];
        --A[i], --B[i];
    }
    int Q;
    cin >> Q;
    vector<int> X(Q), Y(Q), Z(Q);
    for (int i = 0; i < Q; ++i) {
        cin >> X[i] >> Y[i] >> Z[i];
        --X[i], --Y[i];
    }

    // le = 0 (no edges) is assumed insufficient; since the graph is connected,
    // all M edges suffice, so ri = M.
    vector<int> le(Q, 0), ri(Q, M);
    // vec[k]: queries whose midpoint is k, checked once exactly k edges are in.
    vector<vector<int> > vec(M);
    UnionFind uf(N);
    while (true) {
        bool update = false;
        for (auto &bucket : vec) bucket.clear();
        for (int i = 0; i < Q; ++i) {
            if (ri[i] - le[i] > 1) {
                update = true;
                vec[(le[i] + ri[i]) / 2].push_back(i);
            }
        }
        if (!update) break;
        uf.init(N); // reuse the buffer instead of reconstructing each round
        for (int mid = 0; mid < M; ++mid) {
            // uf now contains exactly the first `mid` edges.
            for (int q : vec[mid]) {
                int total = uf.size(X[q]);
                if (!uf.issame(X[q], Y[q])) total += uf.size(Y[q]);
                if (total >= Z[q]) ri[q] = mid;
                else le[q] = mid;
            }
            uf.merge(A[mid], B[mid]);
        }
    }
    // '\n' instead of endl: avoid flushing once per query.
    for (int q = 0; q < Q; ++q) cout << ri[q] << '\n';
}