けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 049 D - Convex Sequence (橙色, 1000 点)

最初は二次元 FFT が必要な気分になっていて右往左往していた。個数制限なしナップサック問題になるのは面白かった!

問題概要

正の整数  N, M が与えられる。長さ  N の非負整数列  A_{1}, \dots, A_{N} であって

  •  A_{1} + \dots + A_{N} = M
  •  2A_{i} \le A_{i-1} + A_{i+1}

という条件を満たすものの個数を、1000000007 で割ったあまりを求めよ。

制約

  •  1 \le N, M \le 10^{5}

考えたこと

要するに「差分が広義単調増加」となるような数列の数え上げ問題ということになる。ざっくり、下に凸な形状をすることになる。なお、0-indexed で考えることにする。

なので、とりあえず最小値を固定して考えたくなる。仮に  A_{0} の値を固定して考えたいと思っても、最初に値を下がっていったとして、「最小値が非負になるようにする」という部分を扱うのは手に負えなさそうなのだ。だから最初に最小値の方を固定するのは理にかなっている気がする。

最小値を  v とし、最小値をとる index を  a とする。最小値を複数とりうる場合は、そのうちの最小の index を  a とすることにした。そうすると、左側と右側に関する問題に分けられる。左側については次のようになる。総和を仮に  m と置く。

  •  x_{0} = v
  •  x_{0} + x_{1} + \dots + x_{a} = m
  •  x_{i+1} - x_{i} \le x_{i+2} - x_{i+1}
  •  x_{i} \ge v + 1

を満たす数列  x_{0}, x_{1}, \dots, x_{a} の個数を求める問題となる。右側については

  •  y_{0} = v
  •  y_{1} + \dots + y_{N-a-1} = M - m
  •  y_{i+1} - y_{i} \le y_{i+2} - y_{i+1}
  •  y_{i} \ge v

を満たす数列  y_{0}, y_{1}, \dots, y_{N-a-1} の個数を求める問題となる。

下に凸な数列の個数

まずは左側の数列の個数を求める問題を考える。階差数列が単調増加になってるような数列の個数を求める問題となる。これは

「階差数列の階差数列が非負整数」

とも言い換えられる。それを  d_{1}, \dots, d_{a} とおくと、

  •  x_{0} = v
  •  x_{1} = v + d_{1}
  •  x_{2} = v + 2d_{1} + d_{2}
  •  x_{3} = v + 3d_{1} + 2d_{2} + d_{3}
  •  x_{4} = v + 4d_{1} + 3d_{2} + 2d_{3} + d_{4}
  • ...
  •  x_{a} = v + a d_{1} + (a-1) d_{2} + \dots + d_{a}

となる。これらの総和が  m になるというのは、次のような個数制限なしナップサック問題に帰着される ( v も変数としてしまっている)。

  • 整数  a+1 が非負整数個 ( v 個)
  • 整数  \frac{1}{2}a(a+1) が 1 個以上 ( d_{1} 個)
  • 整数  \frac{1}{2}(a-1)a が非負整数個 ( d_{2} 個)
  • ...
  • 整数  1 が非負整数個 ( d_{a} 個)

を用いて総和を  m にする方法が何通りあるか?

左と右を合わせる

さらに左側の問題と右側の問題とをマージしてしまうと、総合して次のような個数制限なしナップサック問題と言える。ここで、 b = N-1-a としている。また、 d_{1} \ge 1 という条件を  d_{1} \ge 0 に変換する処理も行っている。


  • 整数  N が非負整数個
  • 整数  \frac{1}{2}a(a+1) が非負整数個
  • 整数  \frac{1}{2}(a-1)a が非負整数個
  • ...
  • 整数  1 が非負整数個
  • 整数  \frac{1}{2}b(b+1) が非負整数個
  • 整数  \frac{1}{2}(b-1)b が非負整数個
  • ...
  • 整数  1 が非負整数個

を用いて総和を  m - \frac{1}{2}a(a+1) にする方法が何通りあるか?


あとはこれを  a = 0, 1, \dots, 2\sqrt{M} について走査する方法を考える。 a \le 2\sqrt{M} について考えれば十分なことに注意する。

解法 (1):左右から累積和 (僕の方法)

 a に対して

  • 整数  \frac{1}{2}a(a+1) が非負整数個
  • 整数  \frac{1}{2}(a-1)a が非負整数個
  • ...
  • 整数  1 が非負整数個

の総和が  m になるような方法の個数を  f(a, m) と表し、

  • 整数  N が非負整数個
  • 整数  \frac{1}{2}a(a+1) が非負整数個
  • 整数  \frac{1}{2}(a-1)a が非負整数個
  • ...
  • 整数  1 が非負整数個

の総和が  m になるような方法の個数を  g(a, m) と表すことにする。 f(a, m) g(a, m) を求める作業は  O(\sqrt{M}M) でできる ( f, g a \ge 2\sqrt{M} の範囲では収束することに注意)。

このとき、各  a = 0, 1, \dots, 2\sqrt{N} m = 0, 1, \dots, M - \frac{1}{2}a(a+1) に対して

 f(a, m) \times g(N-a-1, M-\frac{1}{2}a(a+1) - m)

の総和を求めていけばよい。全体を通して計算量は  O(sqrt{M}M) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

// 個数制限なしナップサックにおいて、値 v を追加する
vector<mint> add(vector<mint> dp, int v) {
    for (int i = v; i < dp.size(); ++i) dp[i] += dp[i-v];
    return dp;
}

int main() {
    int N, M;
    cin >> N >> M;
    int MAX = 1;
    while (MAX * (MAX+1) / 2 <= M) ++MAX;
    vector<vector<mint>> f(MAX+1, vector<mint>(M+1, 0));
    f[0][0] = 1;
    auto g = f;
    g[0] = add(f[0], N);
    for (int a = 1; a <= MAX; ++a) {
        f[a] = add(f[a-1], a*(a+1)/2);
        g[a] = add(g[a-1], a*(a+1)/2);
    }
    mint res = 0;
    for (int a = 0; a <= MAX && a < N; ++a) {
        int b = min(N-a-1, MAX);
        int sa = a * (a+1) / 2;
        for (int m = 0; m <= M - sa; ++m) {
            res += f[a][m] * g[b][M-sa-m];
        }
    }
    cout << res << endl;
}

解法 (2):戻す DP

戻す DP でもできる。各  a = 0, 1, \dots, 2\sqrt{M} に対して、 b = N-a-1 として、

  • 整数  N が非負整数個
  • 整数  \frac{1}{2}a(a+1) が非負整数個
  • 整数  \frac{1}{2}(a-1)a が非負整数個
  • ...
  • 整数  1 が非負整数個
  • 整数  \frac{1}{2}b(b+1) が非負整数個
  • 整数  \frac{1}{2}(b-1)b が非負整数個
  • ...
  • 整数  1 が非負整数個

を用いて総和を  m - \frac{1}{2}a(a+1) にする方法が何通りあるかを総和したい。

 a を増やしたとき

  •  \frac{1}{2}a(a+1) を追加
  •  \frac{1}{2}(N-a)(N-a+1) を削除

という挙動になる。これは戻す DP で実現できる。この場合も計算量は  O(\sqrt{M}M) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

// 個数制限なしナップサックにおいて、値 v を追加する
vector<mint> add(vector<mint> dp, int v) {
    for (int i = v; i < dp.size(); ++i) dp[i] += dp[i-v];
    return dp;
}

// 個数制限なしナップサックにおいて、値 v を削除する
vector<mint> sub(vector<mint> dp, int v) {
    for (int i = dp.size()-1; i >= v; --i) dp[i] -= dp[i-v];
    return dp;
}

int main() {
    int N, M;
    cin >> N >> M;
    int MAX = 1;
    while (MAX * (MAX+1) / 2 <= M) ++MAX;
    vector<mint> dp(M+1, 0);
    dp[0] = 1;
    dp = add(dp, N);
    for (int b = 1; b <= min(N-1, MAX); ++b) dp = add(dp, b*(b+1)/2);
    mint res = dp[M];
    for (int a = 1; a*(a+1)/2 <= M && a < N; ++a) {
        int b = min(N-a, MAX);
        dp = add(dp, a*(a+1)/2);
        dp = sub(dp, b*(b+1)/2);
        res += dp[M-a*(a+1)/2];
    }
    cout << res << endl;
}