けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 028 D - 注文の多い高橋商店 (赤色)

戻す DP を履修して行く!!!

問題へのリンク

問題概要

 N 個の正の整数  A_{1}, \dots, A_{N} と正の整数  M が与えられる。以下の  Q 個のクエリに答えよ。

  • 整数  K, X が与えられる
  • 以下の条件を満たす 0 以上の整数の組 ( x_{1}, \dots, x_{N}) の個数を 1000000007 で割ったあまりを求めよ
    •  0 \le x_{i} \le A_{i}
    •  x_{1} + \dots + x_{N} = M
    •  x_{K} = X

制約

  •  1 \le N, M \le 2000
  •  1 \le Q \le 5 \times 10^{5}

考えたこと

このような「N 個のうち 1 個を除外したものを解きたい」という場面では、次のようなアプローチが代表的かもしれない。

  • 左右両端からの結果を求める
  • 全体の解から 1 個のものを引く

後者のアプローチは単純に総和から値を引く程度の問題は 300 点問題などでよくみられる。しかし、DP でもそういうことができて、「戻す DP」と呼ばれたりする模様。

戻す DP

たとえば部分和数え上げ DP などはこんな感じの更新をする

  • 品物 (値が a) を 1 個追加する前の DP 配列を dp、品物を追加後の DP 配列を ndp としたとき
  • ndp[ v ] = dp[ v ] + dp[ v - a ]

このような演算を  N 回行った結果が最終結果となる。一方、品物を除去する方向性の更新もできる。逆変換をとると

  • dp[ v ] = ndp[ v ] - dp[ v - a ]

という in-place な更新によって逆変換ができる。ここで、 N 回の更新を行った結果は「 N 個の品物をどの順序で更新したとしても等しい」ということに注意しよう。よって、 N 回の更新を行った DP 配列に対して、どの品物を除去しても「残しの  N-1 個の品物についての DP 解」が求められることになる。

今回の場合

今回の DP のステップは次のように書ける (更新に用いる整数を  A とする)。

ndp[ v ] = dp[ v ] + dp[ v - 1 ] + ... + dp[ v - A ]

と書ける。

ndp[ v - 1 ] = dp[ v - 1 ] + ... + dp[ v - A ] + dp[ v - A - 1 ]

と比較することで

ndp[ v ] = ndp[ v - 1 ] + dp[ v ] - dp[ v - A - 1 ]

と簡潔に表せる (実際の DP では更新順序に注意)。これの逆変換をとると

dp[ v ] = dp[ v - A - 1 ] + ndp[ v ] - ndp[ v - 1 ]

となる。

計算量

  •  N 個をマージした結果を求める: O(NM)
  •  N 個それぞれを除去した結果を求める: O(NM)
  • 各クエリに答える: O(Q)

コード

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

vector<mint> add(const vector<mint> &dp, int M, int A) {
    vector<mint> ndp(M+1, 0);
    for (int v = 0; v <= M; ++v) {
        if (v-1 >= 0) ndp[v] += ndp[v-1];
        ndp[v] += dp[v];
        if (v-A-1 >= 0) ndp[v] -= dp[v-A-1];
    }
    return ndp;
}

vector<mint> sub(const vector<mint> &ndp, int M, int A) {
    vector<mint> dp(M+1, 0);
    for (int v = 0; v <= M; ++v) {
        if (v-A-1 >= 0) dp[v] += dp[v-A-1];
        dp[v] += ndp[v];
        if (v-1 >= 0) dp[v] -= ndp[v-1];
    }
    return dp;
}

int main() {
    // 入力
    int N, M, Q;
    cin >> N >> M >> Q;
    vector<int> a(N);
    for (int i = 0; i < N; ++i) cin >> a[i];

    // N 個の結果を総合した結果を求める
    vector<mint> all(M+1, 0);
    all[0] = 1;
    for (int i = 0; i < N; ++i) all = add(all, M, a[i]);

    // 各 i に対して a[i] を除外した結果を求める
    vector<vector<mint>> res(N);
    for (int i = 0; i < N; ++i) res[i] = sub(all, M, a[i]);

    // クエリ処理
    while (Q--) {
        int K, X;
        cin >> K >> X;
        --K;
        cout << res[K][M-X] << endl;
    }
}