けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 108 E - Random IS (赤色, 700 点)

変なところでハマらないようにしたい...

問題概要

 1, 2, \dots, N の順列  a_{1}, \dots, a_{N} が与えられる。いま、これらの順列の各要素に印をつけていくことを考える。ただし、「印のついた要素が左から順に単調増加となるように並んでいる」という条件を常に満たす必要がある。以下のことを繰り返す

  • まだ印のついていない要素のうち、それに印をつけても条件を満たすような要素が  K 個あるとする
  •  K = 0 ならば作業を停止する
  •  K \gt 0 ならば、そのうちの 1 個をランダムに選んで印をつける

作業停止状態における印の個数の期待値を mod 1000000007 で求めよ。

制約

  •  1 \le N \le 2000

最初に考えたこと (嘘)

とりあえず「どの要素を新たに追加しても単調増加である状態が破られる状態」を極大であるということにする。このとき、次のことが言えるのではないかと思った (実際は嘘)

  • 極大な部分列であって、要素数 k であるものの個数を  c_{k} とする
  • このとき、期待値は  \frac{\sum_{k= 1}^{N} c_{k} \times k \times (k!)}{\sum_{k=1}^{N} c_{k} \times (k!)} となる

よって、 c_{1}, \dots, c_{N} が求められれば OK。はじめは、LIS と同じような解法で行けるのではないかと考えた。つまり

  • dp[ i ][ j ][ k ] := 最初の i 個の要素から j 個選んで、それが極大であるもののうち、その最後の要素の値が k であるものの個数

このままだと  O(N^{3}) で、なんとか高速化できないかと考えていた。しかし、そもそも最初の考察が嘘だった。なぜなら、最終的な極大部分列が同一であっても、印をつける順序によって、出現しやすさが異なってくるのだ。

作戦変更

完全に嘘だったので軌道修正!!!!! 印が打たれていく過程にもちゃんと注目していかないといけないことがわかった。印を数値 v に新たに打つ操作は、v より小さい範囲で最大の数値を l、v より大きい範囲で最小の数値を r としたときに、「数値 l と r とで挟まれた区間を、数値 v のところで左右に分割する操作」に対応する。そのことに着目すると、次のような区間 DP が生える。

  • dp[ l ][ r ] := 数値 l と数値 r には印を打たれた状態で、その間の数値区間で印を打つ回数の期待値

このようにしたとき、遷移は次のように表せる。ただし、順列 a の先頭に 0、末尾に N+1 を挿入し、0 と N+1 には確実に印を打つものとして考える。また、順列 a の逆順列を ia とする。

  • l >= r のときは dp[ l ][ r ] = 0
  • l < r かつ ia[l] > ia[r] のときは dp[ l ][ r ] = 0
  • そうでないとき、dp[ l ][ r ] += (dp[ l ][ v ] + dp[ v + 1 ][ r ]) / K + 1 (l < v < r かつ ia[ l ] < ia[ v ] < ia[ r ] のとき。また、これを満たす v の個数を K とする)

のように表せる。このままだと  O(N^{3}) になる。しかし、めちゃくちゃ上手にやると  O(N^{2} \log N) になる。上の遷移式で

  • l < v < r
  • ia[ l ] < ia[ v ] < ia[ r ]

の両方を満たす v についての総和が素早く求められればいいのだが、これらを両立するのが一見難しそうに思える。しかし、

  • 普通の区間 DP は、添字範囲が狭い方から順に DP テーブルを埋めていく

  • 今回の区間 DP は、添字に ia をとった世界において添字範囲が狭い方から順に DP テーブルを埋めていく

という風にすれば上手くいく。このとき、

  • l < v < r は満たすが
  • ia[ l ] < ia[ v ] < ia[ r ] は満たさない

というような dp[ l ][ v ] や dp[ v ][ r ] については未更新で 0 が入っている状態になる!!!

よって、BIT などを用いることで高速化できる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

template <class Abel> struct BIT {
    Abel UNITY_SUM = 0;
    vector<Abel> dat;
    
    // [0, n)
    BIT(int n, Abel unity = 0) : UNITY_SUM(unity), dat(n, unity) { }
    void init(int n) {
        dat.assign(n, UNITY_SUM);
    }
    
    // a is 0-indexed
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i |= i + 1)
            dat[i] = dat[i] + x;
    }
    
    // [0, a), a is 0-indexed
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a - 1; i >= 0; i = (i & (i + 1)) - 1)
            res = res + dat[i];
        return res;
    }
    
    // [a, b), a and b are 0-indexed
    inline Abel sum(int a, int b) {
        return sum(b) - sum(a);
    }
    
    // debug
    void print() {
        for (int i = 0; i < (int)dat.size(); ++i)
            cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};

mint naive(int N, vector<int> a, vector<bool> used) {
    vector<int> ids;
    mint res = 0;
    for (int i = 1; i <= N; ++i) {
        if (used[i]) continue;
        vector<int> ch;
        for (int j = 1; j <= N; ++j) if (used[j] || j == i) ch.push_back(a[j]);
        auto ch2 = ch;
        sort(ch2.begin(), ch2.end());
        if (ch == ch2) ids.push_back(i);
    }
    if (ids.empty()) return res;
    for (auto id : ids) {
        used[id] = true;
        res += naive(N, a, used) + 1;
        used[id] = false;
    }
    return res / ids.size();
}

int main() {
    int N;
    cin >> N;
    vector<int> a(N+2), ia(N+2);
    a[0] = 0, a[N+1] = N+1, ia[0] = 0, ia[N+1] = N+1;
    for (int i = 0; i < N; ++i) cin >> a[i+1], ia[a[i+1]] = i+1;
    
    vector<vector<int>> num(N+3, vector<int>(N+3, 0));
    for (int i = 0; i < N+2; ++i) {
        BIT<int> bit(N+5);
        for (int j = i+1; j < N+2; ++j) {
            int left = a[i], right = a[j];
            if (left < right) num[left][right] = bit.sum(left, right);
            bit.add(right, 1);
        }
    }

    vector<vector<mint>> dp(N+3, vector<mint>(N+3, 0));
    vector<BIT<mint>> left(N+3, BIT<mint>(N+3)), right(N+3, BIT<mint>(N+3));
    for (int bet = 1; bet <= N+1; ++bet) {
        for (int i = 0; i+bet <= N+1; ++i) {
            int l = a[i], r = a[i+bet];
            if (l >= r) continue;
            int K = num[l][r];
            if (K > 0) {
                dp[l][r] = (left[l].sum(l+1, r) + right[r].sum(l+1, r)) / K + 1;
                left[l].add(r, dp[l][r]), right[r].add(l, dp[l][r]);
            }        
        }
    }
    cout << dp[0][N+1] << endl;
}