けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 043 D - Merge Triplets (1200 点)

めちゃくちゃ楽しかった

問題へのリンク

問題概要

長さ  3 の数列を  N 個用意する。ただしこれらのなす  3N 個の値は  1, 2, \dots, 3N が 1 個ずつ登場するようにする。これらの数列から以下の操作を繰り返して、長さ  3N の順列を作る。

  • 空にならずに残っている数列のうち、先頭の要素のみをみたときの最小値を pop して、
  • それを現在形成されている順列の最後尾に加える

作られうる順列の個数を素数  M で割ったあまりを求めよ。

制約

  •  1 \le N \le 2000

考えたこと

入力が実質的に  N の 1 個だけという好きなタイプの問題。埋め込みを封じるために mod も入力として指定されているわけだ。

さて、こういう「操作の結果出来上がるものが何通りありますか」という問題では、操作過程をまともに追いかけてはいけないと痛感している。まずはじめに「こういう順列を作れますか?」という判定問題を解くのだ。それがわかりやすい形で解けると数え上げもできる、というパターンはあまりにも多い。

というわけで、作れる順列の条件を考えよう。色々やっているとまず

  • 3, 8, 5, 4, 2, 6, 1, 7, 9

とかを例にとったとき、(8, 5, 4, 2) という並びで既に詰んでいることがわかる。なぜなら、8 を拾う時点で (5, 4, 2) はむき出しになってはいけないので他レーンに置くことができないのだ。これらは 8 の後ろに並んでなければならない。したがって

  • 要素 x に対して、そこから先に最初に x より大きい要素が出てくるまでの間に、x を含めて 4 個以上の要素がある状態は実現できない

ということがわかった。ここから、

  • 要素 x に対して、最初に x より大きいものが出てくるまでをひとまとめにしたブロック

に分けて考えることにした。このとき、大きさ 2 のブロックを同じレーンに複数個入れることもできないこともわかる。よって必要条件を書き出すと


  • ブロックの大きさは 3 以下でなければならない
  • 大きさが 2 以上のブロックは  N 個以下でなければならない

というのが抽出された。そして色々手を動かしてみると、これが十分条件っぽい。実際

  • 大きさ 3 のブロックは新規レーンに突っ込んでおけばよい
  • 大きさ 2 のブロックは、どこか隙間か、新規レーンに突っ込んでおけば良い
  • 大きさ 1 のブロックは好きにしてよい

という方法で問題ないことがわかる。基本的にブロックの先頭要素は単調増加なので、ブロック単位でまとめて順列に挿入されていくことは保証されるし、その順序がブロック順になることも保証される。

他視点:最大の要素に着目

解説でも、他の多くの方のコメントを見ても、「まず最大の要素に着目して...」という視点から、同じ条件に行き着いた方は多かったみたいだ。僕はあまりそこを意識できていなかった。アルメリアさんの記事にこの着眼点からの導出がある。

DP へ

以上の条件を抽出できたらもう簡単だ。

  • dp[ i ][ j ] := 1, 2, ..., i の順列であって、大きさ 2 以上のブロックが j 個あるような場合の数

とすると、遷移は新規ブロックサイズが 1, 2, 3 の場合に場合分けして挿入 DP すれば OK。

  • dp[ i + 1 ][ j ] += dp[ i ][ j ] (新規挿入が最大値でなければならない)
  • dp[ i + 2 ][ j + 1 ] += dp[ i ][ j ] × (i + 1) (新規挿入の先頭は最大値だが、2 個目の値は i + 1 通り考えられる)
  • dp[ i + 3 ][ j + 1 ] += dp[ i ][ j ] × (i + 2)(i + 1)

という感じ。計算量は O(N2) となる。

#include <bits/stdc++.h>
using namespace std;

// dynamic modint
vector<int> MODS = { 1000000007 }; // 実行時に決まる
template<int IND = 0> struct Fp {
    long long val;
    
    int MOD = MODS[IND];
    constexpr Fp(long long v = 0) noexcept : val(v % MODS[IND]) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<IND>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<IND>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<IND> modpow(const Fp<IND> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

using mint = Fp<>;
mint solve(int N) {
    vector<vector<mint>> dp(N*3+1, vector<mint>(N+1, 0));
    dp[0][0] = 1;
    for (long long i = 0; i <= N*3; ++i) {
        for (long long j = 0; j <= N; ++j) {
            if (i+1 <= N*3) {
                dp[i+1][j] += dp[i][j];
            }
            if (i+2 <= N*3 && j+1 <= N) {
                dp[i+2][j+1] += dp[i][j] * (i+1);
            }
            if (i+3 <= N*3 && j+1 <= N) {
                dp[i+3][j+1] += dp[i][j] * (i+2) * (i+1);
            }
        }
    }
    mint res = 0;
    for (int j = 0; j <= N; ++j) res += dp[N*3][j];
    return res;
}

int main() {
    int N;
    cin >> N >> MODS[0];
    cout << solve(N) << endl;
}