けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces Round #680 (Div. 1) C. Team-Building (R2500)

undo 付き Union-Find ってなんぞ!?

問題概要

頂点数  N、辺数  M の単純無向グラフが与えられる。色が  K 種類あって、各頂点は  1, 2, \dots, K のいずれかの色で塗られている。このとき、以下の条件を満たすような色の組  (i, j) ( 1 \le i \lt j \le K) の個数を求めよ。

  •  N 個の頂点のうち、色  i の頂点と色  j の頂点のみからなる部分グラフを考えたとき、その部分グラフが二部グラフである

制約

  •  1 \le N \le 5 \times 10^{5}
  •  0 \le M \le 5 \times 10^{5}
  •  2 \le K \le 5 \times 10^{5}

考えたこと

単純に全色ペアについて二部グラフ判定したのでは  O(K^{2}(N + M)) の計算量がかかりそう。

少し工夫すると、ダメになる可能性のある色のペアは、両端点の色が異なるような辺  e = (u, v) の両端の色だけであることに注目して、次のようにできる。

  • 各色  k = 1, 2, \dots, K に対してその色の頂点のみで二部グラフ判定を行う

    • それ自体が二部グラフじゃなかったらその色を含むペアはすべてダメ
    • 色単体では二部グラフであるような色 (可能色と呼ぶことにする) の個数を  L とする
  •  \frac{1}{2}L(L-1) から、以下の条件を満たすとき、色ペアを引いていく

    • 各辺  e = (u, v) に対して
    •  c_{u}, c_{v} がともに可能色で、かつ異なる場合について
    •  c_{u}, c_{v} のいずれかの頂点のみからなる部分グラフが二部グラフでないとき

単純にこれをやると  O(M(N+M)) の計算量となる。これでもまだ足りない。

なお、二部グラフ判定は Union-Find でやる方法があって、それを用いると  O(M(N+M)\alpha(N)) となる。具体的には次のようにする。

  • 各頂点を倍加させて、辺  (u, v) に対しては  u, v+N u+N, v を繋ぐ
  • ある頂点  v に対して  v, v+N が同じグループにあったらダメ

rollback つき Union-Find

Union-Find の undo / rollback ができるならば、次のようにできる!

  • まずは両端が同じ色であるような辺のみ、Union-Find を構築する
  • Union-Find の snapshot をとっておく
  • 各辺  e = (u, v) を両端点の色組ごとに分類する
  • 各色組ごとに
    • 該当する辺をすべて Union-Find でマージして「二部グラフが壊れるか」を判定していく
    • 終了後に rollback する

これを実現するための rollback つき Union-Find は、「経路圧縮」を取り止めているので計算量がやや悪化して  O(\alpha(N)) から  O(\log N) になる。

なお、各辺に対して merge と undo がそれぞれ高々 1 回以下呼ばれることになるので、全体を通した計算量は  O(K + (N + M)\log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// rollback Union-Find
struct UnionFind {
    vector<int> par;
    stack<pair<int,int>> history;
    
    UnionFind() {}
    UnionFind(int n) : par(n, -1) { }
    void init(int n) { par.assign(n, -1); }
    
    int root(int x) {
        if (par[x] < 0) return x;
        else return root(par[x]);
    }
    
    bool issame(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x); y = root(y);
        history.emplace(x, par[x]);
        history.emplace(y, par[y]);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y); // merge technique
        par[x] += par[y];
        par[y] = x;
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }

    // 1-step undo
    void undo() {
        for (int iter = 0; iter < 2; ++iter) {
            par[history.top().first] = history.top().second;
            history.pop();
        }
    }

    // erase history
    void snapshot() {
        while (!history.empty()) history.pop();
    }

    // all rollback
    void rollback() {
        while (!history.empty()) undo();
    }
};



//////////////////////////////////////////
// solver
//////////////////////////////////////////

using pint = pair<int,int>;
using Graph = vector<vector<int>>;
long long CF680DIV1C(int N, int M, int K, const vector<int> &C, const Graph &G, UnionFind &uf) {
    // これまでの履歴の削除
    uf.snapshot();

    // すでに二部グラフはダメ
    vector<bool> isbi(K, true);
    long long con = K;
    for (int v = 0; v < N; ++v) {
        if (!isbi[C[v]]) continue;
        if (uf.issame(v, v+N)) --con, isbi[C[v]] = false;
    }
    long long res = con * (con - 1) / 2;

    // 辺の色を分類
    map<pint, vector<pint>> ma;
    for (int v = 0; v < N; ++v) {
        for (auto u : G[v]) {
            if (!isbi[C[v]] || !isbi[C[u]] || C[v] == C[u]) continue;
            ma[pint(min(C[v], C[u]), max(C[v], C[u]))].push_back(pint(v, u));
        }
    }

    // 各ペアごとに追加していく
    for (auto it : ma) {
        bool ok = true;
        for (auto e : it.second) {
            int u = e.first, v = e.second;
            uf.merge(u, v+N), uf.merge(u+N, v);
            if (uf.issame(u, u+N) || uf.issame(v, v+N)) ok = false;
        }
        if (!ok) --res;

        // 元に戻す
        uf.rollback();
    }
    return res;
}

int main() {
    cin.tie(0); 
    ios::sync_with_stdio(false);

    int N, M, K;
    cin >> N >> M >> K;
    vector<int> C(N);
    for (int i = 0; i < N; ++i) cin >> C[i], --C[i];
    Graph G(N);
    UnionFind uf(N*2);
    for (int i = 0; i < M; ++i) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        G[u].push_back(v), G[v].push_back(u);
        if (C[u] == C[v]) uf.merge(u, v+N), uf.merge(u+N, v);
    }
    cout << CF680DIV1C(N, M, K, C, G, uf) << endl;
}