けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

ACL Contest 1 E - Shuffle Window (橙色, 900 点)

コンテスト中に  O(N^{2}) まで導いておきながら、最後詰めきれなかったのは反省。

問題へのリンク

問題概要

 1, 2, \dots, N の順列  p と、2 以上の整数  K が与えられる。 i = 1, 2, \dots, N - K + 1 に対して順に以下の操作を行う。

  •  p_{i}, p_{i+1}, \dots, p_{i-K+1} をランダムシャッフルする

 N- K + 1 回の操作後に得られる順列の転倒数の期待値を、mod 998244353 で求めよ。

制約

  •  2 \le K \le N \le 2 \times 10^{5}

考えたこと

コンテスト本番では、まずは  O(N^{2}) な解を見出そうというのを考えた。具体的には「期待値の線形性」によって、

  •  0 \le j \lt i \lt N なる  (j, i) に対して、
  • それらの相対位置が入れ替わる確率  p(i, j) を求める
  • このとき
    •  p_{i} \lt p_{j} ならば、 1 - p(i, j) を合算する
    •  p_{i} \gt p_{j} ならば、 p(i, j) を合算する

というのをすればよいと考えた。

p(i, j) を求める

まず

  • 添字  j の要素が最初にシャッフル対象となる順序を  j' = \max(0, j - K + 1)
  • 添字  i の要素が最初にシャッフル対象となる順序を  i' = \max(0, i - K + 1)

とする。このとき、添字  (i, j) の相対位置が入れ替わるためには

  •  j' 回目のシャッフルによって先頭に入らない (先頭に入ったらその後二度とシャッフル対象とならない)
    • その確率は  \frac{K-1}{K}
  •  j'+1 回目のシャッフルによって先頭に入らない
    • その確率は  \frac{K-1}{K}
  • ...
  •  i'-1 回目のシャッフルによって先頭に入れない
    • その確率は  \frac{K-1}{K}
  •  i' 回目以降のシャッフルによって相対位置が入れ替わる (この操作後はいかなる操作を行っても 1/2)
    • その確率は、 \frac{1}{2}

ということで、

 p(i, j) = \frac{1}{2}(\frac{K-1}{K})^{i' - j'}

となる。この時点で  O(N^{2}) の解法が得られた。実際に実装してみて、サンプル 2 が合うことを確かめるところまではやった。

BIT で高速化

なぜここまでできて、この BIT による高速化ができなかったのか...思いつかなかったのは、

 p(i, j) = \frac{1}{2}(\frac{K-1}{K})^{i'} (\frac{K-1}{K})^{-j'}

という風に、 i' j' とで分解すること。こうすれば、次のようにして反転数を求めることができる。各  i に対して以下の値を合算した値を求めて、それを  i = 0, 1, \dots, N-1 に対する総和を求めればよい。


  1.  p_{i} \lt p_{j} であるような  j \lt i の個数
  2.  \frac{1}{2} (\frac{K-1}{K})^{i'} \sum_{j: j \lt i, p_{i} \lt p_{j}} (\frac{K-1}{K})^{-j'}
  3.  -\frac{1}{2} (\frac{K-1}{K})^{i'} \sum_{j: j \lt i, p_{i} \gt p_{j}} (\frac{K-1}{K})^{-j'}

これらはそれぞれ BIT を用いて更新していくことで、 O(\log N) で求められる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

template <class Abel> struct BIT {
    Abel UNITY_SUM = 0;
    vector<Abel> dat;
    
    // [0, n)
    BIT(int n, Abel unity = 0) : UNITY_SUM(unity), dat(n, unity) { }
    void init(int n) {
        dat.assign(n, UNITY_SUM);
    }
    
    // a is 0-indexed
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i |= i + 1)
            dat[i] = dat[i] + x;
    }
    
    // [0, a), a is 0-indexed
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a - 1; i >= 0; i = (i & (i + 1)) - 1)
            res = res + dat[i];
        return res;
    }
    
    // [a, b), a and b are 0-indexed
    inline Abel sum(int a, int b) {
        return sum(b) - sum(a);
    }
    
    // debug
    void print() {
        for (int i = 0; i < (int)dat.size(); ++i)
            cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

mint solve(int K, const vector<int> &P) {
    mint res = 0;
    int N = (int)P.size();
    BIT<long long> bnum(N, 0);
    BIT<mint> bpro(N, 0);
    auto index = [&](int i) { return max(0, i - K + 1); };
    for (int i = 0; i < N; ++i) {
        int ti = index(i);
        int v = P[i];
        mint p = mint(K-1) / K;
        mint ip = mint(1) / p;
        mint fac = modpow(p, ti)/2;
        long long num = bnum.sum(v + 1, N);
        mint upper = bpro.sum(v + 1, N), lower = bpro.sum(0, v);
        mint add = mint(1) * num + fac * (lower - upper);
        res += add;
        bnum.add(v, 1);
        bpro.add(v, modpow(ip, ti));
    }
    return res;
}

int main() {
    int N, K;
    cin >> N >> K;
    vector<int> p(N);
    for (int i = 0; i < N; ++i) cin >> p[i], --p[i];
    cout << solve(K, p) << endl;
}