けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 2726 Black Company (JAG 夏合宿 2015 day4-J) (500 点)

素直に考えるとグラフの辺数が  O(N^{2}) のオーダーになってしまうので、いかに削減するかを考える問題だった

問題概要

頂点数  N、辺数  M の単純無向グラフが与えられる。各頂点  v には値  c_{v} が振られている。今、各頂点にスコア  p_{v} を割り振りたい。ただし以下の条件を満たす必要がある。

  •  p_{i} は正の整数
  • 距離が 2 以下の任意の 2 頂点  u, v に対して、
    •  c_{u} \gt c_{v} ならば  p_{u} \gt p_{v} でなければならない
    •  p_{u} \gt p_{v} ならば  c_{u} \gt c_{v} でなければならない

各頂点の  p_{v} の総和として考えられる最小値を求めよ。

制約

  •  1 \le N \le 10^{5}
  •  0 \le M \le 2 \times 10^{5}
  •  1 \le c_{v} \le 10^{5}

考えたこと

まずはナイーブに考えてみる。大前提として、元のグラフ上で距離が 2 以下の 2 頂点  u, v に対しては、 c_{u} = c_{v} である場合には 1 点に縮約してしまうことにする (それらは値が等しくなければならないため)。そうしてできる新たな頂点集合を  V' とする。

そして、以下のような新たな有向グラフ  G' を作る (頂点集合は  V')。グラフ  G' の二頂点  u, v に対して以下のように辺を張る

  •  u, v 間の距離が 3 以上のときは辺を張らない
  •  u, v 間の距離が 2 以下のときは、 c の値が小さい方から大きい方へ辺を張る
    • 等しい部分は縮約しているので大小関係はどちらかに定まる

こうしてできたグラフは DAG になる。この DAG 上で、

  • ソースの値は 1
  • その他の頂点は、いずれかのソースからの最長路長を割り振る

として、その総和を求めれば答えとなる。しかし...このグラフ  G' は辺数が最悪  O(N^{2}) となるのでこのままでは TLE するし MLE もしてしまう。

考えるべき辺数を減らす

このように辺の本数が  O(N^{2}) となってしまう状況においては

「本質的に重要な辺のみを抽出することで辺の本数を  O(N) にする」

という考え方が効くかと思う。そしてそのための考察手段としては、

  • 三角不等式
  • 分枝限定法

などなど、色々考えられそう。今回は三角不等式に則って考えてみる。まず、次のことがいえる。


3 頂点  u, v, w において、 c_{u} \lt c_{v} \lt c_{w} であるとき、辺  (u, v) と辺  (v, w) のみを残せばよく、辺  (u, w) は不要である


具体的には、

  • 元のグラフの有向辺  (u, v) については対応する辺を張る
  • 元のグラフで頂点  v を始点とする任意の辺  (v, w_{k} に対し、各頂点  w_{k} c_{w_{k}} が小さい順にソートしたときに、隣接する頂点間にのみ、辺を張っていく
  • 元のグラフで頂点  v を終点とする任意の辺  (w_{k}, v に対し、各頂点  w_{k} c_{w_{k}} が小さい順にソートしたときに、隣接する頂点間にのみ、辺を張っていく

という風に作った DAG 上で DP すればよさそう。こうしてできたグラフの辺数は  O(M) で抑えられる。

よって計算量は  O(N \log N + M) に改善された。

コード

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

struct UnionFind {
    vector<int> par;
    
    UnionFind(int n) : par(n, -1) { }
    void init(int n) { par.assign(n, -1); }
    
    int root(int x) {
        if (par[x] < 0) return x;
        else return par[x] = root(par[x]);
    }
    
    bool issame(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x); y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y); // merge technique
        par[x] += par[y];
        par[y] = x;
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }
};

using pint = pair<int,int>;
using Graph = vector<vector<int>>;
int main() {
    int N, M;
    cin >> N;
    vector<int> c(N);
    for (int i = 0; i < N; ++i) cin >> c[i];
    cin >> M;
    UnionFind uf(N);
    vector<pint> edges(M);
    for (int i = 0; i < M; ++i) {
        int u, v;
        cin >> u >> v;
        --u, --v;
        edges[i] = pint(u, v);
        if (c[u] == c[v]) uf.merge(u, v);
    }
    Graph nex(N), pre(N);
    for (auto e : edges) {
        int u = e.first, v = e.second;
        if (c[u] < c[v]) nex[u].push_back(v), pre[v].push_back(u);
        else if (c[u] > c[v]) nex[v].push_back(u), pre[u].push_back(v);
    }
    for (int v = 0; v < N; ++v) {
        sort(nex[v].begin(), nex[v].end(), [&](int i, int j) {return c[i] < c[j];});
        sort(pre[v].begin(), pre[v].end(), [&](int i, int j) {return c[i] < c[j];});
        for (int i = 0; i+1 < nex[v].size(); ++i) {
            if (c[nex[v][i]] == c[nex[v][i+1]]) uf.merge(nex[v][i], nex[v][i+1]);
        }
        for (int i = 0; i+1 < pre[v].size(); ++i) {
            if (c[pre[v][i]] == c[pre[v][i+1]]) uf.merge(pre[v][i], pre[v][i+1]);
        }
    }

    Graph G(N);
    for (int v = 0; v < N; ++v) {
        if (!nex[v].empty()) G[uf.root(v)].push_back(uf.root(nex[v][0]));
        for (int i = 0; i+1 < nex[v].size(); ++i) {
            if (c[nex[v][i]] < c[nex[v][i+1]]) {
                G[uf.root(nex[v][i])].push_back(uf.root(nex[v][i+1]));
            }
        }
        for (int i = 0; i+1 < pre[v].size(); ++i) {
            if (c[pre[v][i]] < c[pre[v][i+1]]) {
                G[uf.root(pre[v][i])].push_back(uf.root(pre[v][i+1]));
            }
        }
    }
    vector<long long> dp(N, 0);
    vector<int> nodes(N);
    iota(nodes.begin(), nodes.end(), 0);
    sort(nodes.begin(), nodes.end(), [&](int i, int j) {return c[i] < c[j];});
    for (auto v : nodes) {
        int r = uf.root(v);
        if (v != r) continue;
        chmax(dp[v], 1LL);
        for (auto to : G[v]) chmax(dp[to], dp[r] + 1);
    }
    long long res = 0;
    for (auto v : nodes) res += dp[v] * uf.size(v);
    cout << res << endl;
}