けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

JOIG 2023 E - 運河 (難易度 7)

UnionFind を使って、差分更新を頑張る!
UnionFind はオンラインの処理を簡単に実現できることが強みで、それを問いかける教育的問題ですね。他にも「左右からの結果を前処理する」という典型テクも使います!

問題概要

 H \times W のグリッドがあって、各マス  (i, j) には整数値  A_{i, j} が書かれている。

このグリッドに対して、縦方向に切り取り線を入れて、左右に分割する。そして、左右それぞれの長方形について、「同じ数値が上下左右に隣接している領域」が何個あるかを求める。このとき、

(左側の長方形における、領域の個数) + (右側の長方形における、領域の個数)

として考えられる最小値を求めよ。

制約

  •  H \times W \le 10^{5}
  •  1 \le A_{i, j} \le 10^{9}

解法

まず、

  •  N 個のものが並んでいるときに、左右に分割する方法を、すべて試す (今回の問題もこのタイプ)
  •  N 個のものが並んでいるときに、そのうちの 1 個を除去した残りの  N-1 個について考える作業を、すべて試す

というタイプの問題において、高い確率で有効打となる典型テクニックがあります。それは、次の結果をあらかじめ求めておくことです。


  • left[i] ← 左から  i 個分についての結果
  • right[i] ← 右から  i 個分についての結果

これらをあらかじめ求めておけば、たとえば「 N 個のものについて左から  l 個分の箇所で左右に分割するときのスコア」を求めたいときは、

left[i] + right[N - i] (問題によっては max(left[i], right[N - i]) などのこともある)

などと求められます。そして、多くの問題において、配列 left の求め方さえ考えればよいです。配列 rightleft と同様に求められることが多いためです。

 

今回の場合

今回の問題は、次の配列 left が高速に求められれば良いと言えます。

  • left[i] ← グリッドのうち、左から  H \times i の領域について、同じ数値が上下左右に隣接している箇所を 1 つの領域とみなしたときに、何個の領域に分かれるか

このような、連結成分の個数を求める問題においては、Union-Find が活躍しますね。特に、今回は


 i = 1, 2, \dots, W の順に、

left[i-1] の結果をもとにして、left[i] の結果を求める


という処理をしたいですね。このように差分を毎回更新したい場面では、UnionFind は強いです。今回は  HW 個あるマスに対応する  HW 個のノードを用意しておいて、各  i = 1, 2, \dots, W に対して次の処理をすればよいでしょう。


 i H 個のマスについて、上下左 (右は不要) に同じ数値が隣接しているマス同士について、Union-Find 上で対応するノード同士を併合していく


また、今回は、UnionFind において、「いま何個のグループに分かれているか」を管理する値を持っておきましょう。

  • num ← いま何個のグループに分かれているか (初期値は  WH)

Union-Find 上で併合処理を行うときに、グループが 1 個減るならば、num の値を 1 減らすようにします。最後に、左から  i 列分の処理を終えた段階で「連結成分の個数」は num - H * (W - i) となります。右の  H(W-i) 個のマスについてはまだ考えていない状態なので、その個数を引いています。

 

計算量

以上の解法の計算量を見積ります。

  • leftright を求める計算量: O(HW \alpha(HW))
  • その後の処理に要する計算量: O(W)

ということで、全体の計算量は  O(HW \alpha(HW)) となります。 HW \le 5 \times 10^{5} という制約があるため、十分間に合います。

 

コード

#include <bits/stdc++.h>
using namespace std;

// Union-Find
struct UnionFind {
    // core member
    vector<int> par, nex;

    // constructor
    UnionFind() { }
    UnionFind(int N) : par(N, -1), nex(N) {
        init(N);
    }
    void init(int N) {
        par.assign(N, -1);
        nex.resize(N);
        for (int i = 0; i < N; ++i) nex[i] = i;
    }
    
    // core methods
    int root(int x) {
        if (par[x] < 0) return x;
        else return par[x] = root(par[x]);
    }
    
    bool same(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x), y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y); // merge technique
        par[x] += par[y];
        par[y] = x;
        swap(nex[x], nex[y]);
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }
    
    // get groups
    vector<vector<int>> groups() {
        vector<vector<int>> member(par.size());
        for (int v = 0; v < (int)par.size(); ++v) {
            member[root(v)].push_back(v);
        }
        vector<vector<int>> res;
        for (int v = 0; v < (int)par.size(); ++v) {
            if (!member[v].empty()) res.push_back(member[v]);
        }
        return res;
    }
    
    // debug
    friend ostream& operator << (ostream &s, UnionFind uf) {
        const vector<vector<int>> &gs = uf.groups();
        for (const vector<int> &g : gs) {
            s << "group: ";
            for (int v : g) s << v << " ";
            s << endl;
        }
        return s;
    }
};

int main() {
    // 入力
    int H, W;
    cin >> H >> W;
    vector A(H, vector<int>(W));
    for (int i = 0; i < H; ++i) for (int j = 0; j < W; ++j) cin >> A[i][j];
    
    // マス (i, j) を表す ID
    auto cid = [&](int i, int j) -> int {
        return i * W + j;
    };
    
    // 盤面 A において、左側から i 列分における領域の個数 num[j] を求める
    auto calc = [&]() -> vector<int> {
        vector<int> left(W+1, 0);
        UnionFind uf(H * W);  // 各マスの領域関係
        int num = H * W;  // Union-Find 上のグループの個数
        for (int i = 1; i <= W; ++i) {
            // 列 j を追加して、隣接関係を整理する
            for (int x = 0; x < H; ++x) {
                if (i-2 >= 0 && A[x][i-1] == A[x][i-2]
                    && !uf.same(cid(x, i-1), cid(x, i-2))) {
                    uf.merge(cid(x, i-1), cid(x, i-2));
                    --num;
                }
                if (x-1 >= 0 && A[x][i-1] == A[x-1][i-1]
                    && !uf.same(cid(x, i-1), cid(x-1, i-1))) {
                    uf.merge(cid(x, i-1), cid(x-1, i-1));
                    --num;
                }
            }
            // 連結成分数 - H × (W - i) が答え
            left[i] = num - H * (W - i);
        }
        return left;
    };
    
    // left を求める
    vector<int> left = calc();
    
    // right を求める (左右反転して同じことをするのが楽)
    for (int i = 0; i < H; ++i) reverse(A[i].begin(), A[i].end());
    vector<int> right = calc();

    // 左から l 列、右から W - l 列で分けた結果を調べる
    int res = H * W;
    for (int l = 1; l < W; ++l) res = min(res, left[l] + right[W - l]);
    cout << res << endl;
}