けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 067 E - Grouping (600 点)

slack 勉強会で 600 点の DP として話題になってやってみたん。
DP 自体は素朴だけど、計算量解析含めると 700 点でもいい気はするのんな。

Grouping 問題へのリンク

問題概要 (ARC 067 E)

 N 人をグループ分けしたい。 N 人は互いに区別される。

  • どのグループの人数も  A 人以上  B 人以下である
  •  i = A, A+1, ..., B について、 i 人のグループの個数は  C 個以上  D 個以下である

を満たすようにグループ分けする方法の数を 1000000007 で割ったあまりで求めよ。

制約

  •  1 \le N \le 1000

解法

まずどのグループの人数も  A 人の場合を考えてみる (このとき  N A の倍数でなければならない)。

まず各グループを区別できるものとして考えてみると、 k = N/A として、それぞれの人を  k 個に箱に均等に振り分ける方法は

 \frac{(kA)!}{(A!)^{k}} 通り

になる。実際は箱同士の区別はないので、ここから  k! で割って

 \frac{(kA)!}{(k!)(A!)^{k}} 通り

になる。これを踏まえて以下のような DP をする:


dp[ i ][ j ] :=  N 人のうち  i 人を選んで、どのグループの人数も  A-j 人未満になるようにグループ分けする場合の数

 k = 0,  k = C, C+1, ..., D に対して、 i + (j + A)k \le N ならば、

dp[ i + (j + A)k ][ j + 1 ] += dp[ i ][ j ] *  _{N-i}{\mathrm C}_{(j+A)k} *  \frac{((j+A)k)!}{(k!)((j+A)!)^{k}}


とする。ここで気になるのが、dp[ i ][ j ] から dp[ ナントカ ][ j + 1 ]へと進むときに、(j + A) 人以下のすでにあるグループと、(j + A) 人の新たに追加したグループとで、互いに区別をなくすために何かで割らないといけないような気がしてくるのだが、実は「人数の異なるグループ同士は元々区別できるから割らなくて良い」ことに気づく。

例えば、8 人を 4 人と 4 人に分けるときは、8C4 を 2 で割らないといけないが、8 人を 3 人と 5 人とに分けるときは、8C3 でよい。

計算量解析

この DP、一見  O(N^{3}) でヤバそうなのだが、 k の動く範囲にからくりがあります。 k が本当に  C 以上  D 以下をすべて動くなら、 O(N^{3}) なのですが、 k の動く範囲は、 j が大きくなると調和級数的に小さくなるので結局  O(N^{2}\log{N}) になります。

簡単のため、残り人数を  N 人で、 j = 1, 2, 3, ... としてみて、何グループできるのかを考えてみると、

 N, N/2, N/3, N/4, ...

ずつになります。これを総和すると、実は  O(N\log{N}) になります。理由は、

 1 + 1/2 + 1/3 + ... + 1/N = O(\log{N})

から来ているのですが、これは、 \frac{1}{x}積分すると  \log{x} になることから来ています。

コード

#include <iostream>
using namespace std;

const int MAX = 210000;
const int MOD = 1000000007;

long long fac[MAX], finv[MAX], inv[MAX];
void COMinit(){
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for (int i = 2; i < MAX; i++){
        fac[i] = fac[i - 1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD / i) % MOD;
        finv[i] = finv[i - 1] * inv[i] % MOD;
    }
}
long long COM(int n, int k){
    if (n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
}

inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}

long long pow(long long a, long long n, long long m) {
    if (n == 0) return 1 % m;
    long long t = pow(a, n / 2, m);
    t = mod(t * t, m);
    if (n & 1) t = mod(t * a, m);
    return t;
}


long long dp[1100][1100]; // i 人を j+A 人未満のグループで分ける方法

int main() {
    long long N, A, B, C, D;
    cin >> N >> A >> B >> C >> D;

    COMinit();
    for (int i = 0; i < 1100; ++i) for (int j = 0; j < 1100; ++j) dp[i][j] = 0;
    dp[0][0] = 1;
    for (int i = 0; i <= N; ++i) for (int j = 0; j <= B - A + 1; ++j) {
        int num = A + j;

        // num 人グループが 0 個
        dp[i][j + 1] = (dp[i][j + 1] + dp[i][j]) % MOD;

        // num 人グループが C 以上 D 以下
        for (int k = C; k <= D; ++k) {
            int all = i + num*k;
            if (all > N) break;
            long long choose = COM(N - i, num * k);
            long long divide = fac[num*k] * pow(finv[num], k, MOD) % MOD;
            long long nonkubetsu = finv[k];
            long long mul = choose * divide % MOD * nonkubetsu % MOD;

            dp[all][j + 1] = (dp[all][j + 1] + dp[i][j] * mul % MOD) % MOD;
        }
    }

    cout << dp[N][B - A + 1] << endl;
}