けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 115 E - LEQ and NEQ (黄色, 700 点)

間に合わなかった!!!悔しい!!!

問題概要

長さ  N の数列  A_{1}, \dots, A_{N} が与えられます。以下の条件を満たすような、長さ  N の数列  X_{1}, \dots, X_{N} の個数を 998244353 で割ったあまりを答えよ。

  •  1 \le X_{i} \le A_{i}
  •  X_{i} \neq X_{i+1}

制約

  •  2 \le N \le 5 \times 10^{5}
  •  1 \le A_{i} \le 10^{9}

考えたこと

 X_{i} \neq X_{i+1} という条件は扱いづらいので、包除原理でやると良さそう。

 N-1 個の隙間のうち、いくつかの箇所に「=」を入れると、「=」で繋がれた区間は一つに潰れる感じになる。そしてその区間については「区間内の  A_{i} の最小値」に置き換えてあげる。このように縮退してできる数列に対して、単純に掛け算を取れば OK。

というわけで、基本的には区間ごとに分割していくタイプの DP で扱えそうだ。

  • dp[i] ← 数列の最初の i 個についての場合の数

こうすると、次のような DP によって  O(N^{2}) にはなる。

dp[i] += dp[j] × (A[j:i] の最小値) ×  (-1)^{j-i-1}

とりあえず  (-1)^{j-i-1} という部分がややこしいので、

dp[i] (-1)^{i} × dp[i]

と置き直すことで、次のように変形した。

dp[i] =  - \sum_{j = 1}^{i-1} (dp[j] × (A[j:i] の最小値))

この段階でとりあえずサンプル 2 が合うことを確かめておいた。

DP 高速化

たとえば  A = (1, 2, 5, 4, 3) のとき

  • dp[4] = -(dp[0] * 1 + dp[1] * 2 + dp[2] * 4 + dp[3] * 4)

という感じになっている。ここから dp[5] を考えると、 A_{4} = 3 を考慮に加えることになって、こんな感じになる。

  • dp[5] = -(dp[0] * 1 + dp[1] * 2 + dp[2] * 3 + dp[3] * 3 + dp[4] * 3)

この式は分解すると、

dp[5] = -(dp[0] * 1 + dp[1] * 2) - (dp[2] + dp[3] + dp[4]) * 3
= dp[2] - (dp[2] + dp[3] + dp[4]) * 3

という風に理解できる。つまり、(1, 2, 4, 4) に対して 3 がどこに挿入されるかを考えて、挿入される箇所の前の部分では過去の dp 値、挿入される箇所の後の部分では「累積和 ×  A の値」という感じになっているのだ。

よって、挿入される箇所が特定できれば、DP を高速化できる。僕はこの挿入箇所の特定を、

の遅延評価セグメント木を用意して二分探索する、みたいなことをやって TLE してしまった。

しかしよく考えたら、stack を使えば線形時間でできるのであった!!!

コード

計算量は  O(N) になる。

#include <bits/stdc++.h>
using namespace std;
using pll = pair<long long, long long>;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator < (const Fp& r) const noexcept {
        return this->val < r.val;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N;
    cin >> N;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    stack<pll> st;
    st.push({0, 0});
    vector<mint> dp(N+1, 0), sdp(N+2, 0);
    dp[0] = 1, sdp[1] = 1;
    for (int i = 1; i <= N; ++i) {
        long long up = A[i-1];
        while (!st.empty() && st.top().first >= up) st.pop();
        long long num = st.top().second;
        st.push({up, i});
            
        if (num > 0) dp[i] += dp[num];
        dp[i] -= (sdp[i] - sdp[num]) * up;
        sdp[i+1] = sdp[i] + dp[i];
    }
    mint res = dp[N];
    if (N % 2 == 1) res = -res;
    cout << res << endl;
}

その他の解法

座標圧縮だったり、遅延評価セグメント木を使ったりなどでもできる模様。