けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

MUJIN 2018 F - チーム分け (600 点)

今なら解ける!!!

問題へのリンク

問題概要

 N 人を何チームかに分けたい (チーム同士は区別しない)。 i 人目の人は  a_i 人以下のチームに入るようにする必要がある。そのようなチームの分け方は何通りあるか、 998244353 で割ったあまりを求めよ。

制約

  •  1 \le N \le 1000
  •  1 \le a_i \le N

考えたこと

もし無制約だったらいわゆるスターリング数っぽい問題!!!
そしてこの手の DP は昔は見えなかったけど、今なら見える!!!
とりあえず  a を大きい順に並べておく。

このとき、a を大きい順に見て、 i 人目の所属するグループ人数が  N, N-1, \dots, 1 のどれになるのかを割り当てていくような問題になる。こういう割り当てっぽい構造をした DP は難しいけど慣れれば見える。

箱根駅伝 DP とも言うべきか。。。

  • dp[ x ][ y ] := x 人以上のメンバーで構成されたグループが、y 人分構成済みの状態になる場合の数

とする。つまり「x」以上のラベルを振る相手というのはそもそも a_i >= x となっている i に限られて、その中で y 人分は割り当て済みで、残りは未割当みたいな状態である。

重要なこととして、未割当な人たちは a_i >= x な人たちの一員なので、x を下げていってもこれから割り当てることが可能な人たちである。この手の順番の決め方は上手くやる必要がある。選択肢が広い人たちから順番に決めて行くことで、後にいくらでも調整が効くようにする感じ。

とりあえずそうすると、x 人以上 OK な人が y 人以上いる場合、各 k に対して

  • dp[ x ][ y ] += dp[ x + 1 ][ y - kx ] × (kx 人を選ぶ場合の数) × (kx 人を x 人ずつのグループ分けする場合の数) (x + 1 人以上 OK が y - kx 人以上いるとき)

という感じになる。計算量的には、遷移数が調和級数和のような感じになっていて、結局全体として  O(N^{2}\log{N}) になる。

#include <iostream>
#include <vector>
using namespace std;

const int MOD = 998244353;
const int MAX = 210000;
long long fac[MAX], finv[MAX], inv[MAX];
void COMinit(){
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for(int i = 2; i < MAX; i++){
        fac[i] = fac[i-1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD/i) % MOD;
        finv[i] = finv[i-1] * inv[i] % MOD;
    }
}
long long COM(int n, int k){
    if(n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n-k] % MOD) % MOD;
}

long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

long long modinv(long long a, long long mod) {
    long long b = mod, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a -= t*b; swap(a, b);
        u -= t*v; swap(u, v);
    }
    u %= mod;
    if (u < 0) u += mod;
    return u;
}

int main() {
    COMinit();
    int N; cin >> N;
    vector<int> a(N);
    for (int i = 0; i < N; ++i) cin >> a[i];

    // nums[v] := v 人以上 OK な人数
    vector<long long> nums(N+2, 0);
    for (int i = 0; i < N; ++i) nums[a[i]]++;
    for (int i = N; i >= 0; --i) nums[i] += nums[i+1];

    // DP
    vector<vector<long long> > dp(N+2, vector<long long>(N+1, 0));
    dp[N+1][0] = 1;
    for (long long x = N; x >= 1; --x) {
        for (long long y = 0; y <= nums[x]; ++y) {
            for (long long k = 0; k <= N; ++k) {
                long long y2 = y - x * k;
                if (y2 < 0) break;
                if (y2 > nums[x+1]) continue;
                long long choose = COM(nums[x] - y2, x * k);
                long long fact = fac[x*k] * modinv(modpow(fac[x], k, MOD), MOD) % MOD * finv[k] % MOD;
                dp[x][y] += dp[x+1][y2] * choose % MOD * fact % MOD;
                dp[x][y] %= MOD;
            }
        }
    }
    cout << dp[1][N] << endl;
}