けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ ???? Counting Angels (KUPC 2020 G)

こういう条件を言い換えながら数え上げる問題好き!

問題概要

タプリスちゃんは現在、長さ 1 の数列  A=(1) を持っている。

タプリスちゃんは  A に対して、以下のいずれかの操作を選んで行うことを  N 回繰り返すことにした。

  •  A の末尾に  1 または  M を追加する
  •  1 \le i \lt |A| である整数 i を 1 つ選択し、 a_{i} \lt x \lt a_{i}+1 または  a_{i} \gt x \gt a_{i}+1
が成り立つような整数  x a_{i} と a_{i}+1 の間に追加する

 i 回目の操作を行ったあとの数列  A S_{i} と書くことにする。数列の列  S_{1}, S_{2}, \dots, S_{N} としてありえるものの種類数を 998244353 で割った余りを求めてください。

制約

  •  1 \le N \le 3000
  •  2 \le M \le 10^{8}

考えたこと

これ、次の 2 つのどちらを問いているかによって、解き方が大きく変わってくる!

  • 数列の列  S_{1}, S_{2}, \dots, S_{N} の個数
  • 最終的に出来上がる数列  S_{N} の個数

どちらかというと、後者を問いかけるような問題が多くて、前者のように「操作列自体を数え上げよ」という問題は少ないイメージ。後者であれば「最初に判定問題を解く」という手法がかなり有効になる。

さて、操作列を数え上げよ、と言われるとそれはそれでややこしいので、条件をいい感じに言い換えたい。まず、操作は次の 3 パターンに分類できる。

  • A: 数列  S の末尾の値を  v (= 1 または M) として、末尾に  v を挿入する
  • B: 数列  S の末尾の値を  v として、末尾に  v じゃない方を挿入する ([tex v = 1] なら  M v = M なら  1)
  • C: 数列のどこかに適切な値を挿入する

そして問題は次のように言い換えられる。


'A', 'B', 'C' のみからなる長さ  N の文字列であって、以下の条件を満たすものを数え上げよ。

  • 任意の 'C' に対して、その前にある 'B' の個数を  b, 'C' の個数を  c としたときに、 b(M-2) - c \ge 0 が成立
  • ただし 'C' を挿入するときは、 b(M-2) - c 倍する

ここで、 B を 1 回挟むごとに、挿入できる (場所, 値) のペアの個数が  M-2 箇所増えることに注意する。 C を 1 回挟むごとに、そのペアの個数は 1 減少する。

A をなくして DP へ

さらに、上記の条件を満たす文字列において 'A' の存在は飾りでしかない。よって、次のように考えれば OK。

  • 'A' の個数を  i 個と決め打ちしたとき
  • 文字中の 'A' の箇所を決め打つ方法は  {}_{N}{\rm C}_{i} 箇所あって、
  • 'B' と 'C' のみを  N-i 個並べる問題に帰着する

'B' と 'C' のみを並べる問題は、次のような DP で解ける。

  • dp[ i ][ j ] := 'B' を j 文字使いながら 'B' と 'C' を合計 i 文字並べる方法の個数 (ただし 'C' を並べるときは  b(M-2) - c をかける)

計算量は  O(N^{2}) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

//const int MOD = 1000000007;
const int MOD = 998244353;
using mint = Fp<MOD>;

// Binomial coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

BiCoef<mint> bc;
mint solve(long long N, long long M) {
    M -= 2;
    vector<vector<mint>> dp(N+2, vector<mint>(N+1, 0));
    dp[0][0] = 1;
    for (int i = 0; i <= N; ++i) {
        for (int j = 0; j <= i; ++j) {
            if (dp[i][j] == 0) continue;
            long long s = M * j - (i - j);
            if (j + 1 <= N) {
                dp[i+1][j+1] += dp[i][j];
            }
            if (j - 1 >= 0 && s >= 0) {
                dp[i+1][j] += dp[i][j] * s;
            }
        }
    }
    mint res = 0;
    for (int i = 0; i <= N; ++i) {
        mint tmp = 0;
        for (int j = 0; j <= i; ++j) {
            tmp += dp[i][j];
        }
        res += tmp * bc.com(N, i);
    }
    return res;
}

int main() {
    bc.init(1100000);
    long long N, M;
    while (cin >> N >> M) cout << solve(N, M) << endl;
}