けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 281 Ex - Alchemy (赤色, 600 点)

この平方分割のやり方はちゃんとマスターしたい

問題概要

 A 種類のレベル 1 の宝石がある。各種類のレベル 1 の宝石は無限個ある。

 n 個の宝石を合成することで、レベル  n の宝石を作ることができる。ただし、その  n 個の宝石は次の条件を満たす必要がある:

  • すべてレベル  n 未満である
  • 2 以上の整数  x については「レベル  x の宝石」は高々 1 個である
  • どの 2 個も種類が異なる (レベル 2 以上の宝石の種類が異なるとは、原料となった宝石の種類が異なることを意味する)

レベル  N の宝石を何種類作れるかを 998244353 で割ったあまりで答えよ。

制約

  •  1 \le N \le 2 \times 10^{5}
  •  1 \le A \le 10^{9}

考えたこと

いきなりレベル  N の宝石の種類数を数えるのではなく、レベル  2, 3, \dots の宝石の種類数を数えることにする。なお、dp[i] を、レベル  i の種類数とする

レベル 1

 A 種類である

レベル 2

dp[2] =  {}_{A}\mathrm{C}_{2} 種類である

レベル 3
  • レベル 1 が 3 個の場合: {}_{A}\mathrm{C}_{3} 種類
  • レベル 1 が 2 個、レベル 2 が 1 個の場合: {}_{A}\mathrm{C}_{2} \times dp[2] 種類
レベル 4
  • レベル 1 が 4 個の場合: {}_{A}\mathrm{C}_{4} 種類
  • レベル 1 が 3 個、レベル 2 が 1 個の場合: {}_{A}\mathrm{C}_{3} \times dp[2] 種類
  • レベル 1 が 3 個、レベル 3 が 1 個の場合: {}_{A}\mathrm{C}_{3} \times dp[3] 種類
  • レベル 1 が 2 個、レベル 2 が 1 個、レベル 3 が 1 個の場合: {}_{A}\mathrm{C}_{3} \times dp[2]  \times dp[3] 種類

畳み込みへ

ここまで考えてみて「畳み込みっぽい」式になりそうだとわかる。

  • レベル 1 を表す多項式: f(x) = 1 +  {}_{A}\mathrm{C}_{1}x + {}_{A}\mathrm{C}_{2}x^{2} +  {}_{A}\mathrm{C}_{3}x^{3} + \dots + x^{A}
  • レベル 2 を表す多項式: h_{1}(x) = 1 + dp[1] x
  • レベル 3 を表す多項式: h_{2}(x) = 1 + dp[2] x^{2}
  • ...
  • レベル  n-1 を表す多項式: h_{n-1}(x) = 1 + dp [n-1] x^{n-1}

としたとき、

dp[n] =  \lbrack x^{n} \rbrack f(x)h_{1}(x)h_{2}(x) \dots h_{n-1}(x)

と表せることになる。ここまでで  O(N^{2}) 解法にはなった。具体的には、

for (int i = 1; i <= N; ++i) {
    dp[i] = f[i];
    h[i] = {1, dp[i]};  // h[i] = 1 + dp[i]x
    f *= h[i];
}

というように処理していく。

もし仮に、各  h_{i}(x) の係数が静的に決まるならば、よくある「二分木のような計算順序」によって、 O(N (\log N)^{2}) の計算量となる。

しかし今回は、 h_{i}(x) の係数が  h_{1}(x), h_{2}(x), \dots, h_{i-1}(x) によって動的に決まるため、難しいな......という気持ちになっていた。以前に分割統治 FFT というテクニックを学んでいて、それが使えないかと考えたけどわからなかった。

公式解説は分割統治 FFT

平方分割

公式解説はちょっと天才だと思った。でもコンテスト本番中に通されたコードはほとんど平方分割解法だった。それは無理ない解法だと感じたので、そっちをマスターすることにした。

毎回 f *= h[i] という多項式演算をする代わりに、多項式列  h_{1}, h_{2}, \dots h_{i} の後ろの方を一時的に計算して貯めておくための多項式  g を用意する方法が考えられる。

このとき、dp[i] の値は、 f \times g i 次の項を計算すればよいわけだが、その計算量は  \mathrm{deg}(g) で抑えられるのだ。 B = O(\sqrt{N}) 回ごとに、f *= g をすることにした場合、

  • 毎回の dp[i] を求める計算量: O(\sqrt{N})
  • f *= g の計算量: O(\sqrt{N}) 回 ×  O(N \log{N}) =  O(N \sqrt{N} \log{N})

ということで、全体の計算量は  O(N \sqrt{N} \log{N}) となる。

なお、 B = \sqrt{N \log N} とするともう少し早くなって、 O(N \sqrt{N \log N}) になる。

コード

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using mint = atcoder::modint998244353;

// v に (x + a) をかける
void update(vector<mint>& v, int& deg, mint a) {
    ++deg;
    for (int d = deg; d >= 1; --d) v[d] += v[d-1] * a;
}
int main() {
    long long N, A;
    cin >> N >> A;

    // f, g
    vector<mint> f(N + 1, 0);
    f[0] = 1;
    for (int i = 1; i <= min(N, A); ++i) f[i] = f[i-1] * (A-i+1) / i;

    // 貯める計算過程を表す多項式を初期化する関数
    int B = 2000;
    vector<mint> tmp;
    int deg = 0;
    auto init = [&]() -> void {
        tmp.assign(B, 0), tmp[0] = 1, deg = 0;
    };

    // 計算過程を貯めながら計算する
    mint a = A;
    init();
    for (int n = 2; n <= N; ++n) {
        a = 0;
        for (int d = max(0, (int)(n+1-f.size())); d < tmp.size(); ++d) {
            if (n - d < 0) break;
            a += tmp[d] * f[n - d];
        }
        update(tmp, deg, a);
        if (deg == tmp.size() - 1) {
            f = atcoder::convolution(f, tmp);
            f.resize(N + 1);
            init();
        }
    }
    cout << a.val() << endl;
}