けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 188 C - Honest or Liar or Confused (3D, 黄色, 700 点)

面白かった!!

問題概要

 N 人がいて、それぞれ正直者であるか、嘘つきであるかのいずれかである。また、各人は混乱していないか、混乱しているかのいずれかである。

  • 混乱していない正直者は、常に正しいことをいう
  • 混乱している正直者は、常に間違っていることをいう
  • 混乱していない嘘つきは、常に間違っていることをいう
  • 混乱している嘘つきは、常に正しいことをいう

ここで、 M 個の証言がある。 i 番目の証言は、人  A_{i} によるもので、「人  B_{i} C_{i} (0 のとき正直者、1 のとき嘘つき) だ」というものであった。

証言すべてが整合するように、各人に「混乱していない」「混乱している」を割り当てる方法を 1 つ求めよ。

制約

  •  N, M \le 2 \times 10^{5}

考えたこと

条件がとても複雑なのでシンプルに考えたい。思うに

  • 「正直者」と「嘘つき」を入れ替えると、証言の正しさも入れ替わる
  • 「混乱していない」と「混乱している」を入れ替えると、証言の正さも入れ替わる

ということから、条件は XOR で記述できるはずだと考えた。実際に、試行錯誤の末に、次のように言い換えられた。


 A による証言「 B c だ」を考える。

 A, B が正直者か嘘つきかを表す変数を  a, b (正直者:0、嘘つき:1) とし、 A 混乱しているかどうかを表す変数を  p (混乱していない:0、混乱している:1) とする。このとき、

 a XOR  b XOR  p =  c

が成り立つ


よって、各人  i に対して、正直者かどうかを表す変数と、混乱しているかどうかを表す変数を定義すると、 2N 変数の連立一次方程式を解く問題となる。

変数消去して差分制約系にする

上の式は 3 変数についての関係式になっていて扱いづらいので、2 変数の式にしたい。そこで、混乱しているかどうかを表す変数を消去することにした。

つまり、各人  i について、人  i が混乱しているかどうかを表す値が存在するための条件を考察することにした。そうすると、下図のようになる。

よって、各人が混乱しているかどうかが定まるための必要十分条件を「差分制約系」でかけた。それは、ポテンシャル付き Union-Find などを用いて解くことができる。

計算量は  O(M \log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// Weighted Union-Find (T: the type of v[0], v[1], ..., v[N-1])
template<class T> struct WeightedUnionFind {
    // core member
    vector<int> par;
    vector<T> weight;

    // constructor
    WeightedUnionFind() { }
    WeightedUnionFind(int N, T zero = 0) : par(N, -1), weight(N, zero) {}
    void init(int N, T zero = 0) {
        par.assign(N, -1);
        weight.assign(N, zero);
    }
    
    // core methods
    int root(int x) {
        if (par[x] < 0) return x;
        else {
            int r = root(par[x]);
            weight[x] += weight[par[x]];
            return par[x] = r;
        }
    }
    bool same(int x, int y) {
        return root(x) == root(y);
    }
    int size(int x) {
        return -par[root(x)];
    }
    
    // v[y] - v[x] = w
    bool merge(int x, int y, T w) {
        w += get_weight(x), w -= get_weight(y);
        x = root(x), y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y), w = -w; // merge technique
        par[x] += par[y];
        par[y] = x;
        weight[y] = w;
        return true;
    }
    
    // get v[x]
    T get_weight(int x) {
        root(x);
        return weight[x];
    }
    
    // get v[y] - v[x]
    T get_diff(int x, int y) {
        return get_weight(y) - get_weight(x);
    }
    
    // get groups
    vector<vector<int>> groups() {
        vector<vector<int>> member(par.size());
        for (int v = 0; v < (int)par.size(); ++v) {
            member[root(v)].push_back(v);
        }
        vector<vector<int>> res;
        for (int v = 0; v < (int)par.size(); ++v) {
            if (!member[v].empty()) res.push_back(member[v]);
        }
        return res;
    }
    
    // debug
    friend ostream& operator << (ostream &s, WeightedUnionFind uf) {
        const vector<vector<int>> &gs = uf.groups();
        for (const vector<int> &g : gs) {
            s << "group: ";
            for (int v : g) s << v << "(" << uf.get_weight(v) << ") ";
            s << endl;
        }
        return s;
    }
};

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() : val(0) { }
    constexpr Fp(long long v) : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const { return val; }
    constexpr int get_mod() const { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator + () const { return Fp(*this); }
    constexpr Fp operator - () const { return Fp(0) - Fp(*this); }
    constexpr Fp operator + (const Fp &r) const { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const {
        return this->val != r.val;
    }
    constexpr Fp& operator ++ () {
        ++val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -- () {
        if (val == 0) val += MOD;
        --val;
        return *this;
    }
    constexpr Fp operator ++ (int) const {
        Fp res = *this;
        ++*this;
        return res;
    }
    constexpr Fp operator -- (int) const {
        Fp res = *this;
        --*this;
        return res;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) {
        return os << x.val;
    }
    friend constexpr Fp<MOD> pow(const Fp<MOD> &r, long long n) {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> inv(const Fp<MOD> &r) {
        return r.inv();
    }
};

using mint2 = Fp<2>;

int main() {
    int N, M, a, b, c;
    cin >> N >> M;
    using Edge = pair<int,mint2>;
    vector<vector<Edge>> G(N);
    for (int i = 0; i < M; i++) {
        cin >> a >> b >> c, a--, b--;
        G[a].emplace_back(b, mint2(c));
    }

    WeightedUnionFind<mint2> uf(N);
    for (int v = 0; v < N; v++) {
        if (G[v].size() < 2) continue;
        auto [b1, c1] = G[v][0];
        for (int i = 1; i < G[v].size(); i++) {
            auto [bi, ci] = G[v][i];

            if (!uf.same(b1, bi)) uf.merge(b1, bi, c1 + ci);
            else if (uf.get_diff(b1, bi) != c1 + ci) {
                cout << -1 << endl;
                return 0;
            }
        }
    }

    vector<mint2> res(N, 0);
    for (int v = 0; v < N; v++) {
        if (!G[v].empty()) {
            auto [w, c] = G[v][0];
            mint2 a = uf.get_weight(v), b = uf.get_weight(w);
            res[v] = a + b + c;
        }
        cout << res[v];
    }
    cout << endl;
}