けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 159 F - Knapsack for All Segments (青色, 600 点)

ごちゃごちゃとばぐらせながら何とか通した...

問題へのリンク

問題概要

 N 要素の数列  A_{1}, \dots, A_{N} が与えられる。この数列の  N(N+1)/2 通りの区間それぞれについての

  • 区間内の要素の部分集合であって総和が  S であるものの個数

の総和を求め、998244353 で割ったあまりを求めよ。

制約

  •  1 \le N, S \le A_{i} \le 3000

考えたこと

まず、区間が数列全体の場合についての答えなら、普通の部分和問題を解くような DP で解くことができる。その場合の計算量は  O(NS) となる。しかし計算量的にはもうこれ以上増やせないことがわかる。愚直にやったのでは  O(N^{3}S) となる。なんとかしなければならない。

まずは、この手の問題でよくあることとして、見方を変えるというのがある。問題は「各区間についての総和」だが、逆に

  • 和が S となるような要素の選び方それぞれについて
  • その要素たちを包含するような区間が何個あるかを合計する

という風に考えてみる。少し可能性ありそうな見た目になった。この個数は具体的には、

  • 和が S となる要素の選び方を一つとってきたとき、
  • それらを包含する極小な区間が [left, right) であるとき、
  • 包含する区間の個数は (left + 1) × (N - right + 1) 通りある

ということがわかるので、これを合計すればよい。左端を固定して考えたくなるのだが、それでは計算量が間に合わないので、何とか工夫することを考える。やり方はいくつかありそう。

解法(1): いわゆる耳 DP

状態遷移を意識する DP。今回は

  • まだ左端を選択していない状態 (状態 0)
  • 左端は選択済みだけど右端は選択していない状態 (状態 1)
  • 右端も選択済みの状態 (状態 2)

という状態遷移を素直に落とし込めば OK。通称、耳 DP。

  • dp[ i ][ j ][ t ] := 最初の i 個の要素について、総和が j で、状態が t である場合の場合の数

とする。左端の重みは i + 1、右端の重みは N - i とする感じで遷移すれば OK。

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int N, S;
vector<int> A;

mint solve() {
    vector<vector<vector<mint>>> dp(N+1, vector<vector<mint>>(6500, vector<mint>(3, 0)));
    dp[0][0][0] = 1;
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j <= S; ++j) {
            // from 0
            dp[i+1][j][0] += dp[i][j][0];
            dp[i+1][j+A[i]][1] += dp[i][j][0] * (i+1);
            dp[i+1][j+A[i]][2] += dp[i][j][0] * (i+1) * (N-i);

            // from 1
            dp[i+1][j][1] += dp[i][j][1];
            dp[i+1][j+A[i]][1] += dp[i][j][1];
            dp[i+1][j+A[i]][2] += dp[i][j][1] * (N-i);

            // from 2
            dp[i+1][j][2] += dp[i][j][2];
        }
    }
    return dp[N][S][2];
}

int main() {
    cin >> N >> S;
    A.resize(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    cout << solve() << endl;
}

解法(2): 本番でやった

僕は DP の定義自体に、左端の選択肢に関する場合の数も含めてしまうことにした。

  • dp[ i ][ j ] := 最初の i 個の要素の中から総和を j にする選び方と、その選んだ要素をすべて包含するような区間の左端の取り方との組の個数

として定義した。dp の更新は

  • dp[ i + 1 ][ A[ i ] ] += i + 1 (左端の選び方も含む)
  • j > 0 について、dp[ i + 1 ][ j + A[ i ] ] += dp[ i ][ j ] (左端は固定済み)

という風にできる。そして、dp の第二添字が S になる瞬間について、すなわち dp[ i ][ j ] が dp[ i + 1 ][ S ] へと更新されるタイミングで

  • dp[ i ][ j ] × (N - i)

の値を答えに合算していく。計算量は  O(NS)

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int N, S;
vector<int> A;

mint solve() {
    mint res = 0;
    vector<vector<mint>> dp(N+1, vector<mint>(S+1, 0));
    dp[0][0] = 1;
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j <= S; ++j) dp[i+1][j] += dp[i][j];
        if (A[i] <= S) dp[i+1][A[i]] += i+1;
        for (int j = 1; j <= S; ++j) if (A[i] + j <= S) dp[i+1][j+A[i]] += dp[i][j];

        if (A[i] == S) res += (i+1) * (N-i);
        for (int j = 1; j <= S; ++j) if (A[i] + j == S) res += dp[i][j] * (N - i);
    }
    return res;
}

int main() {
    cin >> N >> S;
    A.resize(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    cout << solve() << endl;
}