けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 328 F - Good Set Query (水色, 525 点)

重み付き Union-Find そのもの。もしくは、列挙可能 Union-Find 使ってマージテクでも。

問題概要

 N 個の整数値  x_{1}, x_{2}, \dots, x_{N} に関する制約条件が  Q 個与えられる。

 i 番目の制約条件では 3 つの整数の組  (a, b, d) が与えられ、 x_{a} - x_{b} = d という形をしている。ここで、次のクエリに答えよ。

  • すでにある制約条件と、この  i 番目の制約条件「 x_{a} - x_{b} = d 」とが矛盾するかどうかを判定せよ

矛盾しない場合には  i を出力し、矛盾する場合には何も出力しないこととする。

制約

  •  1 \le N, Q \le 2 \times 10^{5}

解法 (1):重み付き Union-Find

重み付き Union-Find が使える問題は過去にも色々あった!

drken1215.hatenablog.com

でも、その多くは DFS や BFS でも代用できた。今回は「矛盾があるかどうかの判定を毎回オンラインに実行する」必要があるため、重み付き Union-Find がドンピシャで当てはまる!!

重み付き Union-Find とは、ただのグループ分けを管理する Union-Find ではなく、各要素に重みを持たせたものと言える。そして、同じグループに属する要素については「重みの差分」が分かるように重みを管理する。詳細は次の記事にて!

qiita.com

計算量は  O(N + Q\alpha(N)) となる。

コード (解法 (1))

#include <bits/stdc++.h>
using namespace std;


// Weighted Union-Find (T: the type of v[0], v[1], ..., v[N-1])
template<class T> struct WeightedUnionFind {
    // core member
    vector<int> par;
    vector<T> weight;

    // constructor
    WeightedUnionFind() { }
    WeightedUnionFind(int N, T zero = 0) : par(N, -1), weight(N, zero) {}
    void init(int N, T zero = 0) {
        par.assign(N, -1);
        weight.assign(N, zero);
    }
    
    // core methods
    int root(int x) {
        if (par[x] < 0) return x;
        else {
            int r = root(par[x]);
            weight[x] += weight[par[x]];
            return par[x] = r;
        }
    }
    bool same(int x, int y) {
        return root(x) == root(y);
    }
    int size(int x) {
        return -par[root(x)];
    }
    
    // v[y] - v[x] = w
    bool merge(int x, int y, T w) {
        w += get_weight(x), w -= get_weight(y);
        x = root(x), y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y), w = -w; // merge technique
        par[x] += par[y];
        par[y] = x;
        weight[y] = w;
        return true;
    }
    
    // get v[x]
    T get_weight(int x) {
        root(x);
        return weight[x];
    }
    
    // get v[y] - v[x]
    T get_diff(int x, int y) {
        return get_weight(y) - get_weight(x);
    }
    
    // get groups
    vector<vector<int>> groups() {
        vector<vector<int>> member(par.size());
        for (int v = 0; v < (int)par.size(); ++v) {
            member[root(v)].push_back(v);
        }
        vector<vector<int>> res;
        for (int v = 0; v < (int)par.size(); ++v) {
            if (!member[v].empty()) res.push_back(member[v]);
        }
        return res;
    }
    
    // debug
    friend ostream& operator << (ostream &s, WeightedUnionFind uf) {
        const vector<vector<int>> &gs = uf.groups();
        for (const vector<int> &g : gs) {
            s << "group: ";
            for (int v : g) s << v << "(" << uf.get_weight(v) << ") ";
            s << endl;
        }
        return s;
    }
};


// ABC 328 F
void ABC_328_F() {
    int N, Q;
    cin >> N >> Q;
    WeightedUnionFind<long long> uf(N);
    for (int i = 0; i < Q; ++i) {
        int a, b, d;
        cin >> a >> b >> d;
        --a, --b;
        
        bool good = true;
        if (!uf.same(a, b)) {
            // x[a] - x[b] = d となるように
            uf.merge(b, a, d);
        } else {
            // x[a] - x[b] = d でないとき、ダメ
            if (uf.get_diff(b, a) != d) good = false;
        }
        
        if (good) cout << i+1 << " ";
    }
    cout << endl;
}

int main() {
    ABC_328_F();
}

 

解法 (2):列挙可能 Union-Find + マージテク

列挙可能 Union-Find とは、要素  x を含む根付き木のメンバーたちを返せるようにしたもの。

制約条件  x_{a} - x_{b} = d によって要素  a, b をマージする際に、根付き木のサイズが小さい方の各頂点について、 x の値を調整する。そうすれば、マージテクによって、全体の計算量が  O(N + Q \log N) となる。

コード (解法 (2))

#include <bits/stdc++.h>
using namespace std;


// 列挙可能 Union-Find
struct UnionFind {
    // core member
    vector<int> par, nex;

    // constructor
    UnionFind() { }
    UnionFind(int N) : par(N, -1), nex(N) {
        init(N);
    }
    void init(int N) {
        par.assign(N, -1);
        nex.resize(N);
        for (int i = 0; i < N; ++i) nex[i] = i;
    }
    
    // core methods
    int root(int x) {
        if (par[x] < 0) return x;
        else return par[x] = root(par[x]);
    }
    
    bool same(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x), y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y); // merge technique
        par[x] += par[y];
        par[y] = x;
        swap(nex[x], nex[y]);
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }
    
    // get group
    vector<int> group(int x) {
        vector<int> res({x});
        while (nex[res.back()] != x) res.push_back(nex[res.back()]);
        return res;
    }
    vector<vector<int>> groups() {
        vector<vector<int>> member(par.size());
        for (int v = 0; v < (int)par.size(); ++v) {
            member[root(v)].push_back(v);
        }
        vector<vector<int>> res;
        for (int v = 0; v < (int)par.size(); ++v) {
            if (!member[v].empty()) res.push_back(member[v]);
        }
        return res;
    }
    
    // debug
    friend ostream& operator << (ostream &s, UnionFind uf) {
        const vector<vector<int>> &gs = uf.groups();
        for (const vector<int> &g : gs) {
            s << "group: ";
            for (int v : g) s << v << " ";
            s << endl;
        }
        return s;
    }
};


// ABC 328 F
void ABC_328_F() {
    int N, Q;
    cin >> N >> Q;
    UnionFind uf(N);
    vector<long long> x(N, 0);  // x value
    for (int i = 0; i < Q; ++i) {
        int a, b, d;
        cin >> a >> b >> d;
        --a, --b;
        
        bool good = true;
        if (uf.same(a, b)) {
            // x[a] - x[b] = d でないとき、ダメ
            if (x[a] - x[b] != d) good = false;
        } else {
            // マージテクにより、size(a) < size(b) となるようにする
            if (uf.size(a) > uf.size(b)) {
                swap(a, b);
                d = -d;
            }
            
            // a を含むグループのメンバーたちに足す値を求める
            long long add = (x[b] + d) - x[a];
            
            // x の値を調整して、マージする
            vector<int> ids = uf.group(a);
            for (int id : ids) x[id] += add;
            uf.merge(a, b);
        }

        // good ならば出力する
        if (good) cout << i+1 << " ";
    }
    cout << endl;
}


int main() {
    ABC_328_F();
}