けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

JOI 予選 2019 F - 座席 (AOJ 0657) (3D, 難易度 11)

挿入 DP。。。TDPC O - 文字列を複雑にした問題。アイディアはシンプルだけど詳細詰めが重たい。。。

問題へのリンク

問題概要

 N 個の正の整数  A_1, A_2, \dots, A_N が与えられる。

  •  1 A_1
  •  2 A_2
  • ...
  •  N A_N

を合わせた  A_1 + A_2 + \dots + A_N 個の数を並べる方法のうち、どの隣り合う箇所も数値の差が  2 以上となっているものの個数を  10007 で割ったあまりで求めよ。ただし同じ数同士は区別するものとする。

制約

  •  1 \le N \le 100
  •  1 \le A_i \le 4

考えたこと

同じ数字同士は区別しないものとして考える。そうしておいて最後に  A_1! A_2! \dots A_N! をかければよい。

「同じ数字が隣り合ってはいけない」という条件なら、TDPC O - 文字列とほぼ同じになる。

まず 1 を並べて、2 を挿入しながら並べて、3 を挿入しながら並べて...を N まで行う。このとき、途中経過であれば、「i と i とが隣り合っている」「i と i+1 とが隣り合っている」場所があっても構わないことに注意する。

そこで、


dp[ i ][ a ][ b ][ c ][ d ] := 1 から i までの数を並べたとき、

  • j <= i-1 として、(j, j) や (j-1, j) の箇所が  a 箇所
  • j <= i-2 として、(i-1, j) の箇所が  b 箇所 (条件を満たしている箇所)
  • (i-1, i) が  c 箇所
  • (i, i) が  d 箇所
  • それ以外が  e = s - a - b - c - d 箇所 (条件を満たしている箇所)

となっているようなものの場合の数
(i 文字目までで文字数が  s 個だったとする)


とする。このとき、挿入可能箇所は s + 1 箇所ある。 A_{i+1} 個の "i+1" をそれぞれ s + 1 箇所のうちのどこに挿入するかを考える。ひとまず重複を考えずにどこに挿入するかだけを考える。

  • タイプ a の箇所から  a_2 箇所
  • タイプ b の箇所から  b_2 箇所
  • タイプ c の箇所から  c_2 箇所
  • タイプ d の箇所から  d_2 箇所
  • タイプ e の箇所から  e_2 箇所

に挿入するとき

  • その場合の数は、 aC a_2 ×  bC b_2 ×  cC c_2 ×  dC d_2 ×  eC e_2 通り
  • 重複を考えるとその分の係数は、 s_2 = a_2 + b_2 + c_2 + d_2 + e_2,  r_2 = A_{i+1} - s_2 として、 r_2H s_2 通り

であり、挿入後の状態は

  • タイプ a:  a + c + d - a_2 - c_2 - d_2 箇所
  • タイプ b:  2(a_2 + e_2) + b_2 + c_2 箇所
  • タイプ c:  b_2 + c_2 + 2d_2 箇所
  • タイプ d:  r_2 箇所
  • タイプ e:  b + e - b_2 - e_2 箇所

となる。以上をまとめて DP を実装する。

#include <iostream>
#include <vector>
using namespace std;

const int MAX = 2100;
const int MOD = 10007;

long long fac[MAX], finv[MAX], inv[MAX];
void COMinit(){
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for(int i = 2; i < MAX; i++){
        fac[i] = fac[i-1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD/i) % MOD;
        finv[i] = finv[i-1] * inv[i] % MOD;
    }
}
long long COM(int n, int k){
    if(n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n-k] % MOD) % MOD;
}

long long dp[110][410][10][10][5] = {0}; // a <= 401, b <= 8, c <= 8, d <= 3
int main() {
    COMinit();
    int N; cin >> N;
    vector<int> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    
    dp[0][0][0][0][0] = 1;
    int s = 0;
    for (int i = 0; i < N; ++i) {
        for (int a = 0; a <= 401; ++a) {
            for (int b = 0; b <= 8; ++b) {
                for (int c = 0; c <= 8; ++c) {
                    for (int d = 0; d <= 3; ++d) {
                        int e = s + 1 - a - b - c - d;
                        if (e < 0) continue;
                        if (!dp[i][a][b][c][d]) continue;

                        for (int a2 = 0; a2 <= a; ++a2) {
                            for (int b2 = 0; b2 <= b && a2 + b2 <= A[i]; ++b2) {
                                for (int c2 = 0; c2 <= c && a2 + b2 +c2 <= A[i]; ++c2) {
                                    for (int d2 = 0; d2 <= d && a2 + b2 +c2 + d2 <= A[i]; ++d2) {
                                        for (int e2 = 0; e2 <= e && a2 + b2 +c2 + d2 + e2 <= A[i]; ++e2) {
                                            int s2 = a2 + b2 + c2 + d2 + e2;
                                            int r2 = A[i] - s2;
                                            long long fact = COM(a, a2);
                                            fact = (fact * COM(b, b2)) % MOD;
                                            fact = (fact * COM(c, c2)) % MOD;
                                            fact = (fact * COM(d, d2)) % MOD;
                                            fact = (fact * COM(e, e2)) % MOD;
                                            fact = (fact * COM(r2 + s2 - 1, r2)) % MOD;
                                            
                                            int na = a + c + d - a2 - c2 - d2;
                                            int nb = (a2 + e2) * 2 + b2 + c2;
                                            int nc = b2 + c2 + d2 * 2;
                                            int nd = r2;
                                            //int ne = b + e - b2 - e2;
                                            dp[i+1][na][nb][nc][nd] += dp[i][a][b][c][d] * fact % MOD;
                                            dp[i+1][na][nb][nc][nd] %= MOD;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        s += A[i];
    }
    long long res = 0;
    for (int b = 0; b <= 8; ++b) res = (res + dp[N][0][b][0][0]) % MOD;
    for (int i = 0; i < N; ++i) res = (res * fac[A[i]]) % MOD;
    cout << res << endl;
}