けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 380 G - Another Shuffle Window (3D, 青色, 575 点)

公式解説の方がシンプルだった。

問題概要

 (1, 2, \dots, N) の順列  P と、整数  K が与えられる。

 P の連続する  K 個の要素からなる区間( N - K + 1 通りある)をランダムに選び、さらにその区間をランダムシャッフルする。

最終的な順列の転倒数の期待値を mod 998244353 で求めよ。

制約

  •  1 \le K \le N \le 2 \times 10^{5}

考えたこと

公式解説では、区間を固定したときの期待値を求めて、その結果を区間を動かしながら求めるような方法をしていた。そっちの方がシンプルだった。

僕は、各ペア  (P_{i}, P_{j}) ( i \lt j) に対して、順序が転倒する確率を求めて、その総和を求める方針をとった。この確率は次のように求められる。


  •  P_{i} \gt P_{j} のとき:
    •  j - i \ge K のとき:1
    • そうでないとき: 1 - \frac{\min(i, N-K)}{2(N-K+1)} + \frac{\max(j-K, -1)}{2(N-K+1)}
  •  P_{i} \lt P_{j} のとき:
    •  j - i \ge K のとき:0
    • そうでないとき: \frac{\min(i, N-K)}{2(N-K+1)} - \frac{\max(j-K, -1)}{2(N-K+1)}

よって、 j を固定したときの各  i に対する上記の値の総和を素早く求めるために、次の 3 種類の BIT を持つことにした。

  •  j - i \ge K であるような  i についての、添字  v = P_{i} の位置には 1 が立っている (それ以外は 0) ような BIT
  •  j - i \lt K であるような  i についての、添字  v = P_{i} の位置には 1 が立っている (それ以外は 0) ような BIT
  •  j - i \lt K であるような  i についての、添字  v = P_{i} の位置には  \min(i, N-K) が立っている (それ以外は 0) ような BIT

これらの情報があれば、高速に求められる。計算量は  O(N \log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() : val(0) { }
    constexpr Fp(long long v) : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const { return val; }
    constexpr int get_mod() const { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator + () const { return Fp(*this); }
    constexpr Fp operator - () const { return Fp(0) - Fp(*this); }
    constexpr Fp operator + (const Fp &r) const { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const {
        return this->val != r.val;
    }
    constexpr Fp& operator ++ () {
        ++val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -- () {
        if (val == 0) val += MOD;
        --val;
        return *this;
    }
    constexpr Fp operator ++ (int) const {
        Fp res = *this;
        ++*this;
        return res;
    }
    constexpr Fp operator -- (int) const {
        Fp res = *this;
        --*this;
        return res;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) {
        return os << x.val;
    }
    friend constexpr Fp<MOD> pow(const Fp<MOD> &r, long long n) {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> inv(const Fp<MOD> &r) {
        return r.inv();
    }
};

// BIT
template <class Abel> struct BIT {
    Abel UNITY_SUM = 0;
    vector<Abel> dat;
    
    // [0, n)
    BIT(int n, Abel unity = 0) : UNITY_SUM(unity), dat(n, unity) { }
    void init(int n) {
        dat.assign(n, UNITY_SUM);
    }
    
    // a is 0-indexed
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i |= i + 1)
            dat[i] = dat[i] + x;
    }
    
    // [0, a), a is 0-indexed
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a - 1; i >= 0; i = (i & (i + 1)) - 1)
            res = res + dat[i];
        return res;
    }
    
    // [a, b), a and b are 0-indexed
    inline Abel sum(int a, int b) {
        return sum(b) - sum(a);
    }
    
    // debug
    void print() {
        for (int i = 0; i < (int)dat.size(); ++i)
            cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N, K;
    cin >> N >> K;
    vector<int> P(N);
    for (int i = 0; i < N; i++) cin >> P[i];

    mint res = 0;
    BIT<long long> ex_num(N + 2);
    BIT<long long> in_num(N + 2);
    BIT<mint> in_sum(N + 2);
    for (int j = 0; j < N; j++) {
        // 区間の左端を j - K >= 0 側に移す
        if (j - K >= 0) {
            int v = P[j - K];
            in_num.add(v, -in_num.sum(v, v + 1));
            in_sum.add(v, -in_sum.sum(v, v + 1));
            ex_num.add(v, 1);
        }

        // j を固定したときの、各 i に対する (P[i], P[j]) 部分の転倒数の期待値を求める
        long long num_ex_upper = ex_num.sum(P[j]+1, N+2);
        long long num_in_upper = in_num.sum(P[j]+1, N+2);
        long long num_in_lower = in_num.sum(0, P[j]);
        mint sum_in_upper = in_sum.sum(P[j]+1, N+2);
        mint sum_in_lower = in_sum.sum(0, P[j]);
        res += mint(num_ex_upper);
        res += (mint(max(j-K, -1)) / 2 / (N - K + 1) + 1) * num_in_upper - sum_in_upper / 2 / (N - K + 1);
        res -= (mint(max(j-K, -1)) / 2 / (N - K + 1)) * num_in_lower - sum_in_lower / 2 / (N - K + 1);

        // 区間の右端 (P[j]) を反映させる
        in_num.add(P[j], 1);
        in_sum.add(P[j], mint(min(j, N - K)));
    }
    cout << res << endl;
}