けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

JOI 春合宿 2010 day1-3 Stairs (難易度 6)

区間分割していく DP を普通にやると  O(N^{2}) になる (オレンジの出荷もそう)。それを累積和を用いて高速化する。

問題概要 (意訳)

 N 個の正の整数  H_{1}, \dots, H_{N} が与えられる。これらをいくつかの連続した区間に分割していく。ただしどの区間についても、総和が  P 以下でなければならない。

そのような分割方法の個数を 1234567 で割ったあまりを求めよ。

制約

  •  1 \le N \le 5 \times 10^{5}

考えたこと

区間を分割するタイプの問題は、DP で解けることが多い。こんな感じの DP をするのだ。

  • dp[ i ] := 最初の i 個の整数をいくつかの区間に分割する方法の個数

このとき、次のような DP 遷移が作れる

dp[ i ] += dp[ j ] (区間 [j, i) の総和が P 以下であるような j に対して)

この DP の計算量は  O(N^{2}) となる。ここまでは「オレンジの出荷」なんかも似た DP だ。これを高速化しよう。

DP 高速化

DP 遷移を改めてちゃんと書くと、こんなふうになる。

  • 区間 [j, i) の総和が  P 以下となるような最小の j を j = l(i) とすると、
  • dp[ i ] = dp[ l(i) ] + dp[ l(i) + 1 ] + ... + dp[ i - 1 ]

ここで配列 dp の累積和 sdp を導入すると、

  • dp[ i ] = sdp[ i ] - sdp[ l(i) ]

と簡潔に書ける。このようにすれば計算量は  O(N) に下がる。

コード

ここでは modint を用いることにした。また、l(i) を求める部分は「しゃくとり法」を用いた。

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1234567;
using mint = Fp<MOD>;

int main() {
    long long N, P;
    cin >> N >> P;
    vector<long long> H(N);
    for (int i = 0; i < N; ++i) cin >> H[i];

    // しゃくとり法
    vector<int> l(N+1, N+1);
    int left = 0;
    long long sum = 0;
    for (int right = 0; right <= N; ++right) {
        while (left < right && sum > P) sum -= H[left++];
        l[right] = left;
        if (right < N) sum += H[right];
    }

    // 累積和を用いた DP
    vector<mint> dp(N+1, 0), sdp(N+2, 0);
    dp[0] = sdp[1] = 1;
    for (int n = 1; n <= N; ++n) {
        dp[n] = sdp[n] - sdp[l[n]];
        sdp[n+1] = sdp[n] + dp[n];
    }
    cout << dp[N] << endl;
}