けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

第一回日本最強プログラマー学生選手権-予選- F - Candy Retribution (銅色, 1000 点)

てんぷらたんのこれを思い出した!!!

yukicoder.me

問題へのリンク

問題概要

 N 要素からなる非負整数  A_1, \dots, A_N であって

  •  L \le A_1 + \dots + A_N \le R
  •  N 要素を大きい順に並べたとき、 M 番目と  M+1 番目とが等しい

という条件を満たすものの個数を  1000000007 で割ったあまりを求めよ。

制約

  •  1 \le M \lt N \le 3 \times 10^{5}
  •  1 \le L \le R \le 3 \times 10^{5}

考えたこと

まず、 L 以上  R 以下という制約を見て真っ先に

  •  R 以下のものについての個数から
  •  L-1 以下のものについての個数を引く

としてよいことがわかる。よって問題は整数  K が与えられたときに

  •  A_1 + \dots A_N \le K
  •  N 要素を大きい順に並べたとき、 M 番目と  M+1 番目とが等しい

という条件を満たすものを数え上げる問題に帰着した。ここから先は悩む。いろんな方針が考えられると思う。

  • 制約が小さければ DP でできる...それを母関数だの FFT だので高速化する系?
  • 根本的に重複組合せ的な考え方でなんとかする?

といったことをあれこれ考えて迷走してしまいそう。ここでは重複組合せを使う方向性を考えてみる。重複組合せを考えると、まずは以下のことがいえる


  •  A_1 + \dots + A_N \le K
  •  A_i は非負整数

を満たす数列の個数は  C(K + N, N) 通りである


重複組合せの亜種

重複組合せ亜種として

  •  A_1 + \dots + A_N \le K
  •  b \le A_i \le c

という感じの数え上げも、包除原理を用いて  O(N) 程度の計算時間で、解くことができることを知っておくと考えやすい。これはまさにてんぷらしゃんが yukicoder で出していた問題!!!

drken1215.hatenablog.com

上手に問題を分割

問題の条件は  M 番目と  M+1 番目とが等しいという条件なので、その値で場合分けしようと思うのが自然な気がする。でもその値を  v として、

  • 上位  M 個が  v 以上
  • 下位  N-M 個が  v 以下

になるようなものを数えようという気持ちになる。ここで上位と下位が完全に独立に考えられる状態になったら楽なのだけど、、、このままだと  v が被っているのでややこしい。そこで敢えて補集合を数え上げることにする。そのさいに上位の最小値を固定するとやりやすそう

  • 上位  M 個のうちの最小値が  v
  • 下位  N-M 個はすべて  v-1 以下

という条件を満たすものを数え上げられたらよい。ここでやりやすいのは、まず  N 要素のうちの  M 箇所を上位側に指定して、残りを下位側に指定して、最後に  C(N, M) をかけることにして、独立に考えることができる。上位と下位とで同一の要素を共有していないからだ。

...が、この辺りで僕は詰まった。上位と下位とを独立に考えられると思ったのだけど、上位の総和と下位の総和として考えられる場合が  O(K) 通りくらいあって、それぞれについて下位側では包除原理で  O(K) くらいの計算量を要するので全部で  O(K^{2}) くらいの計算量になってしまう。さらにこれを各  v についてやるので全体で  O(K^{3}) とかかかってしまいそう。これを高速化するのは思いつかなかった。

まずは「最小値が v」をどうにかする

まずは上位側の「最小値が  v」の扱い方について。これは安直には

  •  v をとるやつが  i 個の場合を求め、 i について合計する

というふうにしても良さそうだけど、もっとよくできる。すなわち

  • 最小値が  v 以上の場合の数から
  • 最小値が  v+1 以上の場合の数を引く

とすれば OK。

実際には上位側と下位側とをまとめて数える

上位と下位とをまとめて数えてしまうことにする。つまり

  •  A_{1} + \dots + A_{N-M} + A_{N-M+1} + \dots + A_{N} = K
  •  A_{1}, \dots, A_{N-M} \le v-1
  •  A_{N-M+1}, \dots, A_{N} \ge v

の場合の数を求め、そこから

  •  A_{1} + \dots + A_{N-M} + A_{N-M+1} + \dots + A_{N} = K
  •  A_{1}, \dots, A_{N-M} \le v-1
  •  A_{N-M+1}, \dots, A_{N} \ge v+1

の場合の数を求めて引けばよい。

包除原理へ

  •  A_{1} + \dots + A_{N-M} + A_{N-M+1} + \dots + A_{N} \le K
  •  A_{1}, \dots, A_{N-M} \le v-1
  •  A_{N-M+1}, \dots, A_{N} \ge v

を満たす場合を数えてみる。まず  \ge v について処理するのは簡単で、各項からあらかじめ  v を引いておくことで

  •  A_{1} + \dots + A_{N-M} + A_{N-M+1} + \dots + A_{N} \le K - Mv
  •  A_{1}, \dots, A_{N-M} \le v-1
  •  A_{N-M+1}, \dots, A_{N} \ge 0

の場合に帰着できる。二番目の条件は扱いづらいが、包除原理でできる。この部分だけなら過去にてんぷらしゃんによる yukicoder での出題がある!

drken1215.hatenablog.com

 A_{1}, \dots, A_{N-M} のうち  i 個分の条件が違反する場合の数は、 i 個分が  \ge v という条件を満たす必要があることから、

  •  (-1)^{i} C(N-M, i) \times
  •  C(K - Mv - iv + N, N) 通り

となる。これを各  i について足せばよい。この部分について  i の動く範囲をよく考えると、じつは

 K - Mv - iv \ge 0

を満たす必要があるので、実は雑に見積もっても  i \le K/v 程度ということがわかる。

計算量

 v = 0, 1, \dots, K に対して包除原理をするので一見  O(K^{2}) になるように思えるけれどもそうではなく、上記の包除原理中の  i (制約違反箇所数) の動ける範囲が  K/v 以下ということで、

  •  K(1 + 1/2 + 1/3 + \dots + 1/K) 〜 O(K\log{K})

という程度の計算量になることがわかる。

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

// 二項係数ライブラリ
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;
BiCoef<mint> bc;


long long N, M, L, R;

// 下位 N-M 個が v-1 以下、sum = K2 (K2 は K - Mv などを表す気持ち)
mint subsolve(long long v, long long K2) {
    mint res = 0;
    for (long long i = 0; i * v <= K2; ++i) {
        mint tmp = bc.com(N-M, i) * bc.com(K2 - i*v + N, N);
        if (i & 1) res -= tmp;
        else res += tmp;
    }
    return res;
}

mint solve(long long K) {
    mint res = 0;
    for (long long v = 1; v <= K; ++v) {
        mint tmp = 0;
        
        // 上位側が v 以上
        tmp += subsolve(v, K - M*v);

        // 上位側が v+1 以上
        tmp -= subsolve(v, K - M*(v+1));

        res += tmp;
    }
    return bc.com(K + N, N) - bc.com(N, M) * res;
}

int main() {
    bc.init(1100000);
    cin >> N >> M >> L >> R;
    cout << solve(R) - solve(L-1) << endl;
}