けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 167 D - Good Permutation (黄色, 700 点)

モノグサでバチャやって、なんとか通した!

問題概要

 1, 2, \dots, N の順列  P が与えられる。この順列に対して「2 個の要素を選んで swap する」という操作を実行して、

「順列から誘導される Functional Graph の連結成分の個数が 1 個」

という状態を実現したい。そのための最小回数を  M としたとき、 M 回の操作で実現される、上記の状態を実現した順列のうち、辞書順最小の物を求めよ。

(マルチテストケース)

制約

  •  2 \le N \le 10^{5}

考えたこと

順列  P から誘導される Functional Graph の連結成分数を  A とすると、最小で  A-1 回の操作で実現できることは容易にわかる。

できあがる辞書順最小のものを求めるためには、


 x = 1, 2, \dots, N に対して

  • 全体が 1 つのサイクルになっていたら処理を終了する
  • 頂点  x を含むサイクルに含まれない頂点番号の最小値を  m とする
  •  P_{x} \gt m であるか、 x を含むサイクルの頂点番号の最大値が  x であるとき: x の次の頂点が  m となるようにサイクルを繋ぎかえる

というようにしたい。

難しいのは  m を求めるところで、僕は次のようにした。

  • Functional Graph の各サイクルを UnionFind で管理する
  • UnionFind 上で根となる頂点以外には値 INF を持たせて、根となる頂点にはそのサイクルに含まれる頂点番号の最小値を持たせる
  • さらに、これらの値を RMQ で管理する

僕の方法では計算量は  O(N \log N) となった。

コード

#include <bits/stdc++.h>
using namespace std;

// Segment Tree
template<class Monoid> struct SegmentTree {
    using Func = function<Monoid(Monoid, Monoid)>;

    // core member
    int N;
    Func OP;
    Monoid IDENTITY;
    
    // inner data
    int log, offset;
    vector<Monoid> dat;

    // constructor
    SegmentTree() {}
    SegmentTree(int n, const Func &op, const Monoid &identity) {
        init(n, op, identity);
    }
    SegmentTree(const vector<Monoid> &v, const Func &op, const Monoid &identity) {
        init(v, op, identity);
    }
    void init(int n, const Func &op, const Monoid &identity) {
        N = n;
        OP = op;
        IDENTITY = identity;
        log = 0, offset = 1;
        while (offset < N) ++log, offset <<= 1;
        dat.assign(offset * 2, IDENTITY);
    }
    void init(const vector<Monoid> &v, const Func &op, const Monoid &identity) {
        init((int)v.size(), op, identity);
        build(v);
    }
    void pull(int k) {
        dat[k] = OP(dat[k * 2], dat[k * 2 + 1]);
    }
    void build(const vector<Monoid> &v) {
        assert(N == (int)v.size());
        for (int i = 0; i < N; ++i) dat[i + offset] = v[i];
        for (int k = offset - 1; k > 0; --k) pull(k);
    }
    int size() const {
        return N;
    }
    Monoid operator [] (int i) const {
        return dat[i + offset];
    }
    
    // update A[i], i is 0-indexed, O(log N)
    void set(int i, const Monoid &v) {
        assert(0 <= i && i < N);
        int k = i + offset;
        dat[k] = v;
        while (k >>= 1) pull(k);
    }
    
    // get [l, r), l and r are 0-indexed, O(log N)
    Monoid prod(int l, int r) {
        assert(0 <= l && l <= r && r <= N);
        Monoid val_left = IDENTITY, val_right = IDENTITY;
        l += offset, r += offset;
        for (; l < r; l >>= 1, r >>= 1) {
            if (l & 1) val_left = OP(val_left, dat[l++]);
            if (r & 1) val_right = OP(dat[--r], val_right);
        }
        return OP(val_left, val_right);
    }
    Monoid all_prod() {
        return dat[1];
    }
    
    // get max r such that f(v) = True (v = prod(l, r)), O(log N)
    // f(IDENTITY) need to be True
    int max_right(const function<bool(Monoid)> f, int l = 0) {
        if (l == N) return N;
        l += offset;
        Monoid sum = IDENTITY;
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(OP(sum, dat[l]))) {
                while (l < offset) {
                    l = l * 2;
                    if (f(OP(sum, dat[l]))) {
                        sum = OP(sum, dat[l]);
                        ++l;
                    }
                }
                return l - offset;
            }
            sum = OP(sum, dat[l]);
            ++l;
        } while ((l & -l) != l);  // stop if l = 2^e
        return N;
    }

    // get min l that f(get(l, r)) = True (0-indexed), O(log N)
    // f(IDENTITY) need to be True
    int min_left(const function<bool(Monoid)> f, int r = -1) {
        if (r == 0) return 0;
        if (r == -1) r = N;
        r += offset;
        Monoid sum = IDENTITY;
        do {
            --r;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(OP(dat[r], sum))) {
                while (r < offset) {
                    r = r * 2 + 1;
                    if (f(OP(dat[r], sum))) {
                        sum = OP(dat[r], sum);
                        --r;
                    }
                }
                return r + 1 - offset;
            }
            sum = OP(dat[r], sum);
        } while ((r & -r) != r);
        return 0;
    }
    
    // debug
    friend ostream& operator << (ostream &s, const SegmentTree &seg) {
        for (int i = 0; i < (int)seg.size(); ++i) {
            s << seg[i];
            if (i != (int)seg.size() - 1) s << " ";
        }
        return s;
    }
};

// Union-Find
struct UnionFind {
    // core member
    vector<int> par, nex;

    // constructor
    UnionFind() { }
    UnionFind(int N) : par(N, -1), nex(N) {
        init(N);
    }
    void init(int N) {
        par.assign(N, -1);
        nex.resize(N);
        for (int i = 0; i < N; ++i) nex[i] = i;
    }
    
    // core methods
    int root(int x) {
        if (par[x] < 0) return x;
        else return par[x] = root(par[x]);
    }
    
    bool same(int x, int y) {
        return root(x) == root(y);
    }
    
    bool merge(int x, int y) {
        x = root(x), y = root(y);
        if (x == y) return false;
        if (par[x] > par[y]) swap(x, y); // merge technique
        par[x] += par[y];
        par[y] = x;
        swap(nex[x], nex[y]);
        return true;
    }
    
    int size(int x) {
        return -par[root(x)];
    }
    
    // get group
    vector<int> group(int x) {
        vector<int> res({x});
        while (nex[res.back()] != x) res.push_back(nex[res.back()]);
        return res;
    }
    vector<vector<int>> groups() {
        vector<vector<int>> member(par.size());
        for (int v = 0; v < (int)par.size(); ++v) {
            member[root(v)].push_back(v);
        }
        vector<vector<int>> res;
        for (int v = 0; v < (int)par.size(); ++v) {
            if (!member[v].empty()) res.push_back(member[v]);
        }
        return res;
    }
    
    // debug
    friend ostream& operator << (ostream &s, UnionFind uf) {
        const vector<vector<int>> &gs = uf.groups();
        for (const vector<int> &g : gs) {
            s << "group: ";
            for (int v : g) s << v << " ";
            s << endl;
        }
        return s;
    }
};

void solve() {
    int N;
    cin >> N;
    vector<int> P(N), IP(N);
    for (int i = 0; i < N; ++i) {
        cin >> P[i];
        --P[i];
        IP[P[i]] = i;
    }
    
    set<int> S;
    for (int i = 0; i < N; ++i) S.insert(i);
    
    int INF = N;
    UnionFind uf(N);
    SegmentTree<int> seg(N, [&](int a, int b){ return min(a, b); }, INF);
    SegmentTree<int> seg2(N, [&](int a, int b){ return max(a, b); }, -INF);
    for (int i = 0; i < N; ++i) {
        seg.set(i, i);
        seg2.set(i, i);
    }
    
    auto merge = [&](int x, int y) -> void {
        if (uf.same(x, y)) return;
        x = uf.root(x), y = uf.root(y);
        uf.merge(x, y);
        int minv = min(seg.prod(x, x+1), seg.prod(y, y+1));
        int maxv = max(seg2.prod(x, x+1), seg2.prod(y, y+1));
        if (uf.root(x) == x) {
            seg.set(x, minv), seg.set(y, INF);
            seg2.set(x, maxv), seg2.set(y, -INF);
        } else {
            seg.set(x, INF), seg.set(y, minv);
            seg2.set(x, -INF), seg2.set(y, maxv);
        }
    };
    
    auto get_min = [&](int x) -> int {
        x = uf.root(x);
        int left = seg.prod(0, x);
        int right = seg.prod(x+1, N);
        return min(left, right);
    };
    auto get_max = [&](int x) -> int {
        x = uf.root(x);
        return seg2.prod(x, x+1);
    };
    
    auto swap2 = [&](int x, int y) -> void {
        merge(x, y);
        int x2 = P[x], y2 = P[y];
        swap(P[x], P[y]);
        swap(IP[x2], IP[y2]);
    };
    
    for (int x = 0; x < N; ++x) {
        merge(x, P[x]);
    }

    for (int x = 0; x < N; ++x) {
        int m = get_min(x);
        if (m == INF) break;
        if (P[x] > m || x == get_max(x)) {
            
            int y = IP[m];
            swap2(x, y);
        }
    }
    
    for (int i = 0; i < N; ++i) {
        cout << P[i]+1 << " ";
    }
    cout << endl;
}

int main() {
    int T;
    cin >> T;
    while (T--) solve();
}