けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces Round #609 (Div. 1) C. K Integers (R2300)

とてもこどふぉっぽい問題だと思った!!!こういうのを得意になるぞー!!!

問題へのリンク

問題概要

 1, 2, \dots, N の順列が与えられる。各  k = 1, 2, \dots, N に対して、以下の問いに答えよ。

  • 順列の隣り合う 2 要素を swap して、順列のどこかの場所で  1, 2, \dots, k がこの順に連続で並んでいる状態にしたい
  • それを実現する最小 swap 回数を求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

 k に関する問題は以下の二部パートにわかれそう

  •  1, 2, \dots, k を連続した場所に集める
  • それらの転倒数を求める

後者については、 1, 2, \dots, k-1 の部分についての転倒数に、 k を足すことで転倒数がいくつ増えるのかを BIT で求めることができる。これは「いつもの転倒数の求め方」と一緒なので、難しくない。

前者については、 1, 2, \dots, k の存在する index を  b_{1}, \dots, b_{k} としたとき、

  •  b_{1}, \dots, b_{k} のメディアンについてはその位置を固定
  • その前後をメディアンのところに持ってくる

とするのが最小になる。このことは、よくある「絶対値の和を最小にするのはメディアンのとき」というのとまったく同様にして証明できる。

僕はメディアンの位置を求めて、その前後に集めるのに必要な操作回数を求める作業は priority_queue を 2 つ使う実装をした。でもよく考えたら、転倒数を求めるのに使った BIT を使いまわせばよかった。

いずれにしても計算量は  O(N\log{N})

メディアンを求めるのに、BIT を使い回す方法

#include <iostream>
#include <vector>
#include <queue>
using namespace std;
#define COUT(x) cout << #x << " = " << (x) << " (L" << __LINE__ << ")" << endl
 
template <class Abel> struct BIT {
    const Abel UNITY_SUM = 0;                     // to be set
    vector<Abel> dat;
    
    /* [1, n] */
    BIT(int n) : dat(n + 1, UNITY_SUM) { }
    void init(int n) { dat.sssign(n + 1, UNITY_SUM); }
    
    /* a is 1-indexed */
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i += i & -i)
            dat[i] = dat[i] + x;
    }
    
    /* [1, a], a is 1-indexed */
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i)
            res = res + dat[i];
        return res;
    }
    
    /* [a, b), a and b are 1-indexed */
    inline Abel sum(int a, int b) {
        return sum(b - 1) - sum(a - 1);
    }

    /* k-th number (k is 0-indexed) */
    int get(long long k) {
        ++k;
        int res = 0;
        int N = 1; while (N < (int)dat.size()) N *= 2;
        for (int i = N / 2; i > 0; i /= 2) {
            if (res + i < (int)dat.size() && dat[res + i] < k) {
                k = k - dat[res + i];
                res = res + i;
            }
        }
        return res + 1;
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat.size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};
 
int N;
vector<int> p, ip; // ip: 逆順列
 
vector<long long> solve() {
    vector<long long> res(N, 0);
    BIT<long long> bit(N+10), bit2(N+10);
    long long tensum = 0;
    for (int i = 0; i < N; ++i) {
        // 追加転倒数
        long long add_tentou = bit.sum(ip[i]+1, N + 5);
        tensum += add_tentou;
        
        // 情報更新
        bit.add(ip[i]+1, 1);
        bit2.add(ip[i]+1, ip[i]);
        
        // メディアン
        long long medi = bit.get(i/2) - 1; // 0-index に

        // メディアンに集める量
        long long left = i/2;
        long long right = i - i/2;
        long long left_sum = bit2.sum(1, medi + 1);
        long long right_sum = bit2.sum(medi + 2, N + 5);
        long long left_move = ((medi - 1) + (medi - left)) * (i/2) / 2 - left_sum;
        long long right_move = right_sum - ((medi + 1) + (medi + right)) * right / 2;
        res[i] = tensum + left_move + right_move;
    }
 
    return res;
}
 
int main() {
    while (scanf("%d", &N) != EOF) {
        p.resize(N); ip.resize(N);
        for (int i = 0; i < N; ++i) {
            scanf("%d", &p[i]), --p[i];
            ip[p[i]] = i;
        }
        auto res = solve();
        for (int i = 0; i < N; ++i) {
            if (i) printf(" ");
            printf("%lld", res[i]);
        }
        printf("\n");
    }
}

メディアンを求めるのに、priority_queue を 2 つ使った方法

#include <iostream>
#include <vector>
#include <queue>
using namespace std;
 
template <class Abel> struct BIT {
    const Abel UNITY_SUM = 0;                     // to be set
    vector<Abel> dat;
    
    /* [1, n] */
    BIT(int n) : dat(n + 1, UNITY_SUM) { }
    void init(int n) { dat.sssign(n + 1, UNITY_SUM); }
    
    /* a is 1-indexed */
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i += i & -i)
            dat[i] = dat[i] + x;
    }
    
    /* [1, a], a is 1-indexed */
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i)
            res = res + dat[i];
        return res;
    }
    
    /* [a, b), a and b are 1-indexed */
    inline Abel sum(int a, int b) {
        return sum(b - 1) - sum(a - 1);
    }

    /* k-th number (k is 0-indexed) */
    int get(long long k) {
        ++k;
        int res = 0;
        int N = 1; while (N < (int)dat.size()) N *= 2;
        for (int i = N / 2; i > 0; i /= 2) {
            if (res + i < (int)dat.size() && dat[res + i] < k) {
                k = k - dat[res + i];
                res = res + i;
            }
        }
        return res + 1;
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat.size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};
 
 
int N;
vector<int> p, ip; // ip: 逆順列
 
vector<long long> solve() {
    vector<long long> tentou(N, 0);
    BIT<long long> bit(N+10);
    for (int i = 0; i < N; ++i) {
        long long tmp = bit.sum(ip[i]+1, N+9);
        if (i > 0) tentou[i] = tentou[i-1] + tmp;
        bit.add(ip[i]+1, 1);
    }
 
    vector<long long> res = tentou;
    priority_queue<long long> zen;
    priority_queue<long long, vector<long long>, greater<long long> > kou;
    long long zensum = 0, kousum = 0;
 
    for (int i = 0; i < N; ++i) {
        long long add = ip[i];
 
        // push to zen
        if (zen.size() == kou.size()) {
            if (kou.empty()) zen.push(add), zensum += add;
            else {
                long long mi = kou.top();
                if (add > mi) {
                    kou.pop(); kousum -= mi;
                    zen.push(mi); zensum += mi;
                    kou.push(add); kousum += add;
                }
                else {
                    zen.push(add); zensum += add;
                }
            }
        }
        // push to kou
        else {
            if (zen.empty()) kou.push(add), kousum += add;
            else {
                long long ma = zen.top();
                if (add < ma) {
                    zen.pop(); zensum -= ma;
                    kou.push(ma); kousum += ma;
                    zen.push(add); zensum += add;
                }
                else {
                    kou.push(add); kousum += add;
                }
            }
        }
        long long median = zen.top();
        long long zennum = (long long)zen.size() - 1;
        long long kounum = kou.size();
        long long zenkaisa = ((median - 1) + (median - zennum)) * zennum / 2;
        long long koukaisa = ((median + 1) + (median + kounum)) * kounum / 2;
 
        long long zenadd = zenkaisa - (zensum - median);
        long long kouadd = kousum - koukaisa;
        res[i] += zenadd + kouadd;
    }
    return res;
}
 
int main() {
    while (scanf("%d", &N) != EOF) {
        p.resize(N); ip.resize(N);
        for (int i = 0; i < N; ++i) {
            scanf("%d", &p[i]), --p[i];
            ip[p[i]] = i;
        }
        auto res = solve();
        for (int i = 0; i < N; ++i) {
            if (i) printf(" ");
            printf("%lld", res[i]);
        }
        printf("\n");
    }
}