けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 256 E - Takahashi's Anguish (水色, 500 点)

最近話題の Functional Graph の問題!

問題概要

 1, 2, \dots, N がいる。各人  i には 1 人ずつ嫌いな人  X_{i} がいる。

今、彼らに順番にキャンディーを配る。ただし、各  i について、もし人  i よりも先に  X_{i} にキャンディーを配ると、不満度が  C_{i} だけ加算される。

キャンディーを配る順序を最適化することで、不満度の最小値を求めよ。

制約

  •  2 \le N \le 2 \times 10^{5}

考えたこと

問題文を一目見ただけでも、いくつかの典型要素がある。


  • Functional Graph であること (各頂点につき、辺が 1 本だけ出ているグラフ)
  • 順列を最適化する系の問題であること

後者の要素がある分、他の Functional Graph の問題に比べるとやや難しめになる。なお、人  i を頂点  i とみなし、頂点  i から頂点  X_{i} へとコスト  C_{i} の有向辺を張ってできる Functional Graph を考えている。

さて、Functional Graph といえば「各連結成分につきサイクルが 1 個ずつある」という性質がとても特徴的だ。この問題も、サイクルを元に考察したい。そもそもサイクルがなければ、トポロジカルソート順に配っていけば丸く収まる。サイクルがあると、必ずどこかで不満が発生する。

ここまで考察すれば、もう答えは分かったも同然だ。サイクル中のコストが最小の辺を犠牲にして、キャンディーを配ればいい。コスト最小の辺以外からは、不満を発生させないように配ることが可能である。

 

具体的な解法

与えられた Functional Graph を連結成分ごとに考えていき、各連結成分で発生するサイクル内の辺のコストの最小値を足していけばよい。

実装としては、以下のものが考えられる。いずれも  O(N) O(N\alpha(N)) といった計算量で解ける。


  1. Kruskal 法の要領で、コストの大きい辺から順に、両端の頂点を Union-Find でマージしていき、すでに同じグループにあったら、その辺のコストを足していく
  2. Functional Graph の閉路検出するいつものをやる (頂点  v から辺を辿っていき、すでに訪れた頂点に訪れることで閉路検出する)
  3. 一般的なサイクル検出ライブラリで殴る

解法 1 は「Kruskal 法の要領」であると書いたが、Kruskal 法そのものでもある。というのも、我々のやりたいことは「最大全域森を求める」とも言い換えられるからだ。

各連結成分ごとに最大全域木 (最小全域木ではなく) を求めてあげて、それにふくまれない辺が 1 本だけあるので、そのコストを足していくものと考えればよい。

 

コード

コード 1 (Kruskal 法)

これが一番楽だと思う。

#include <bits/stdc++.h>
using namespace std;

// Union-Find
struct UnionFind {
    // core member
    vector<int> par;

    // constructor
    UnionFind() { }
    UnionFind(int n) : par(n, -1) { }
    void init(int n) { par.assign(n, -1); }
    
    // core methods
    int root(int x) {
        if (par[x] < 0) return x;
        else return par[x] = root(par[x]);
    }
    
    bool same(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x), y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y); // merge technique
        par[x] += par[y];
        par[y] = x;
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }
    
    // debug
    friend ostream& operator << (ostream &s, UnionFind uf) {
        map<int, vector<int>> groups;
        for (int i = 0; i < uf.par.size(); ++i) {
            int r = uf.root(i);
            groups[r].push_back(i);
        }
        for (const auto &it : groups) {
            s << "group: ";
            for (auto v : it.second) s << v << " ";
            s << endl;
        }
        return s;
    }
};

int main() {
    using Edge = pair<long long, pair<int,int>>;
    int N;
    cin >> N;
    vector<long long> X(N), C(N);
    for (int i = 0; i < N; ++i) cin >> X[i], --X[i];
    for (int i = 0; i < N; ++i) cin >> C[i];
    
    // C が大きい順にソート
    vector<int> ids(N);
    for (int i = 0; i < N; ++i) ids[i] = i;
    sort(ids.begin(), ids.end(), [&](int i, int j){return C[i] > C[j];});
    
    // Kruskal 法的な要領で
    long long res = 0;
    UnionFind uf(N);
    for (auto i : ids) {
        if (uf.same(i, X[i])) res += C[i];
        uf.merge(i, X[i]);
    }
    cout << res << endl;
}

コード 2 (Functional Graph の閉路検出)

各連結成分ごとにすべて閉路検出しようと思うと、意外と面倒。

#include <bits/stdc++.h>
using namespace std;

// Edge Class
template<class T> struct Edge {
    int from, to;
    T val;
    Edge() : from(-1), to(-1), val(-1) { }
    Edge(int f, int t, T v = -1) : from(f), to(t), val(v) {}
    friend ostream& operator << (ostream& s, const Edge& E) {
        return s << E.from << "->" << E.to;
    }
};

// G[v] := 頂点 v から出ている辺
template<class T> struct CycleDetection {
    // input
    vector<Edge<T>> G;
    
    // intermediate results
    vector<bool> seen, finished;
    vector<int> history;
    
    // constructor
    CycleDetection() { }
    CycleDetection(const vector<Edge<T>> &graph) { init(graph); }
    void init(const vector<Edge<T>> &graph) {
        G = graph;
        seen.assign(G.size(), false);
        finished.assign(G.size(), false);
    }
    
    // return the vertex where cycle is detected
    int search(int v) {
        do {
            seen[v] = true;
            history.push_back(v);
            v = G[v].to;
            if (finished[v]) {
                v = -1;
                break;
            }
        } while (!seen[v]);
        pop_history();
        return v;
    }
    
    // pop history
    void pop_history() {
        while (!history.empty()) {
            int v = history.back();
            finished[v] = true;
            history.pop_back();
        }
    }
    
    // reconstruct
    vector<Edge<T>> reconstruct(int pos) {
        // reconstruct the cycle
        vector<Edge<T>> cycle;
        int v = pos;
        do {
            cycle.push_back(G[v]);
            v = G[v].to;
        } while (v != pos);
        return cycle;
    }
    
    // find cycle, v is the start vertex
    vector<Edge<T>> detect_from_v(int v) {
        int pos = search(v);
        if (pos != -1) return reconstruct(pos);
        else return vector<Edge<T>>();
    }
    
    // find all cycle
    vector<vector<Edge<T>>> detect_all() {
        vector<vector<Edge<T>>> res;
        for (int v = 0; v < (int)G.size(); ++v) {
            if (finished[v]) continue;
            int pos = search(v);
            if (pos == -1) continue;
            const vector<Edge<T>> &cycle = reconstruct(pos);
            if (!cycle.empty()) res.push_back(cycle);
        }
        return res;
    }
};

int main() {
    int N;
    cin >> N;
    vector<long long> X(N), C(N);
    for (int i = 0; i < N; ++i) cin >> X[i], --X[i];
    for (int i = 0; i < N; ++i) cin >> C[i];
    
    // Functional Graph 構築
    vector<Edge<long long>> G(N);
    for (int i = 0; i < N; ++i) G[i] = Edge<long long>(i, X[i], C[i]);
    
    // 閉路検出
    using Cycle = vector<Edge<long long>>;
    CycleDetection<long long> cd(G);
    const vector<Cycle> &cycles = cd.detect_all();
    
    // 集計
    long long res = 0;
    for (const auto &cycle : cycles) {
        long long min_cost = 1LL<<60;
        for (const auto &e : cycle) min_cost = min(min_cost, e.val);
        res += min_cost;
    }
    cout << res << endl;
}

コード 3 (一般的なサイクル検出)

最後に、完全にライブラリで殴る方法。

#include <bits/stdc++.h>
using namespace std;

// Edge Class
template<class T> struct Edge {
    int from, to;
    T val;
    Edge() : from(-1), to(-1), val(-1) { }
    Edge(int f, int t, T v = -1) : from(f), to(t), val(v) {}
    friend ostream& operator << (ostream& s, const Edge& E) {
        return s << E.from << "->" << E.to;
    }
};

// graph class
template<class T> struct Graph {
    vector<vector<Edge<T>>> list;
    
    Graph(int n = 0) : list(n) { }
    void init(int n = 0) {
        list.assign(n, vector<Edge<T>>());
    }
    const vector<Edge<T>> &operator [] (int i) const { return list[i]; }
    const size_t size() const { return list.size(); }
        
    void add_edge(int from, int to, T val = -1) {
        list[from].push_back(Edge(from, to, val));
    }
    
    void add_bidirected_edge(int from, int to, T val = -1) {
        list[from].push_back(Edge(from, to, val));
        list[to].push_back(Edge(to, from, val));
    }

    friend ostream &operator << (ostream &s, const Graph &G) {
        s << endl;
        for (int i = 0; i < G.size(); ++i) {
            s << i << " -> ";
            for (const auto &e : G[i]) s << e.to << " ";
            s << endl;
        }
        return s;
    }
};

// cycle detection
template<class T> struct CycleDetection {
    // input
    Graph<T> G;
    
    // intermediate results
    vector<bool> seen, finished;
    vector<Edge<T>> history;
    
    // constructor
    CycleDetection() { }
    CycleDetection(const Graph<T> &graph) { init(graph); }
    void init(const Graph<T> &graph) {
        G = graph;
        seen.assign(G.size(), false);
        finished.assign(G.size(), false);
    }
    
    // dfs
    // return the vertex where cycle is detected
    int dfs(int v, const Edge<T> &e, bool is_prohibit_reverse = true) {
        seen[v] = true;
        for (const Edge<T> &e2 : G[v]) {
            if (is_prohibit_reverse && e2.to == e.from) continue;
            if (finished[e2.to]) continue;

            // detect cycle
            if (seen[e2.to] && !finished[e2.to]) {
                history.push_back(e2);
                finished[v] = true;
                return e2.to;
            }

            history.push_back(e2);
            int pos = dfs(e2.to, e2, is_prohibit_reverse);
            if (pos != -1) {
                finished[v] = true;
                return pos;
            }
            history.pop_back();
        }
        finished[v] = true;
        return -1;
    }
    
    // reconstruct
    vector<Edge<T>> reconstruct(int pos) {
        vector<Edge<T>> cycle;
        while (!history.empty()) {
            const Edge<T> &e = history.back();
            cycle.push_back(e);
            history.pop_back();
            if (e.from == pos) break;
        }
        reverse(cycle.begin(), cycle.end());
        return cycle;
    }
    
    // find cycle, v is the start vertex
    vector<Edge<T>> detect_from_v(int v, bool is_prohibit_reverse = true) {
        history.clear();
        int pos = dfs(v, Edge<T>(), is_prohibit_reverse);
        if (pos != -1) return reconstruct(pos);
        else return vector<Edge<T>>();
    }
    
    // find cycle
    vector<Edge<T>> detect(bool is_prohibit_reverse = true) {
        int pos = -1;
        for (int v = 0; v < (int)G.size() && pos == -1; ++v) {
            if (seen[v]) continue;
            history.clear();
            pos = dfs(v, Edge<T>(), is_prohibit_reverse);
            if (pos != -1) return reconstruct(pos);
        }
        return vector<Edge<T>>();
    }
};

int main() {
    int N;
    cin >> N;
    Graph<long long> G(N);
    vector<long long> X(N), C(N);
    for (int i = 0; i < N; ++i) cin >> X[i];
    for (int i = 0; i < N; ++i) cin >> C[i];
    for (int i = 0; i < N; ++i) G.add_edge(i, X[i]-1, C[i]);

    long long res = 0;
    CycleDetection<long long> cd(G);
    for (int v = 0; v < N; ++v) {
        if (cd.seen[v]) continue;
        
        // 頂点 v から探索を開始して見つかるサイクルを検出する
        const auto &cycle = cd.detect_from_v(v, false);
        if (cycle.empty()) continue;
        long long minv = 1LL<<60;
        for (const auto &e : cycle) minv = min(minv, e.val);
        res += minv;
    }
    cout << res << endl;
}