けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 106 B - Values (茶色, 400 点)

問題概要

頂点数  N、辺数  M の無向グラフが与えられる。各頂点  v には値  a_{i} が書かれている。以下の操作を好きな順序で好きな回数だけ行うことで、各頂点  v の数値が  b_{v} であるような状態にすることが可能かどうかを判定せよ。

  •  (u, v) を選んで、以下のいずれかを行う
    •  a_{u} を +1 して、 a_{v} を -1 する
    •  a_{u} を -1 して、 a_{v} を +1 する

制約

  •  N, M \le 2 \times 10^{5}

考えたこと

まず明らかに、 a の総和と  b の総和が等しくなければならない。

でも少し色々試してみると、 a の総和と  b の総和が等しくて、グラフが連結でありさえすれば、できそうだというのも見えてくる。

具体的には、グラフのうちのどこか 1 つの頂点  v について、 a_{v} =  b_{v} となるように調整することができるはずだ。そうしたら残りの頂点からまたどこか 1 つの頂点を揃えて...と繰り返していくと、最終的に全体の辻褄があって全体が揃いそう。

一応注意点として、揃える頂点として「関節点」を選んではいけない。関節点を選んでしまうと、残りの頂点がバラバラになってしまうからだ。木であれば「葉」を揃えていけば、ちょうど「葉」を切り落とすような感じで行ける。一般のグラフであっても、連結であれば全域木をとることができるので、その全域木において「葉」から順番に揃えて切り落としていくイメージで OK。

以上から、

  • 連結なグラフにおいて
  •  a の総和と  b の総和が等しい

という条件を満たせば、可能であることがわかった。連結でない一般のグラフについては、「すべての連結成分について  a の総和と  b の総和が等しい」ことが条件となる。

コード

グラフを連結成分ごとに分けるのに、Union-Find を用いた。この場合の計算量は  O(N + M\alpha(N)) となる。

#include <bits/stdc++.h>
using namespace std;

struct UnionFind {
    vector<int> par;
    
    UnionFind(int n) : par(n, -1) { }
    void init(int n) { par.assign(n, -1); }
    
    int root(int x) {
        if (par[x] < 0) return x;
        else return par[x] = root(par[x]);
    }
    
    bool issame(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x); y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y); // merge technique
        par[x] += par[y];
        par[y] = x;
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }
};

int main() {
    int N, M; cin >> N >> M;
    vector<long long> a(N), b(N);
    for (int i = 0; i < N; ++i) cin >> a[i];
    for (int i = 0; i < N; ++i) cin >> b[i];

    UnionFind uf(N);
    for (int i = 0; i < M; ++i) {
        int x, y; cin >> x >> y; --x, --y;
        uf.merge(x, y);
    }

    vector<long long> sa(N, 0), sb(N, 0);
    for (int v = 0; v < N; ++v) {
        int r = uf.root(v);
        sa[r] += a[v], sb[r] += b[v];
    }
    bool res = true;
    for (int v = 0; v < N; ++v) {
        int r = uf.root(v);
        if (sa[r] != sb[r]) res = false;
    }
    if (res) cout << "Yes" << endl;
    else cout << "No" << endl;
}