この平方分割のやり方はちゃんとマスターしたい
問題概要
種類のレベル 1 の宝石がある。各種類のレベル 1 の宝石は無限個ある。
個の宝石を合成することで、レベル
の宝石を作ることができる。ただし、その
個の宝石は次の条件を満たす必要がある:
- すべてレベル
未満である
- 2 以上の整数
については「レベル
の宝石」は高々 1 個である
- どの 2 個も種類が異なる (レベル 2 以上の宝石の種類が異なるとは、原料となった宝石の種類が異なることを意味する)
レベル の宝石を何種類作れるかを 998244353 で割ったあまりで答えよ。
制約
考えたこと
いきなりレベル の宝石の種類数を数えるのではなく、レベル
の宝石の種類数を数えることにする。なお、
dp[i]
を、レベル の種類数とする
レベル 1
種類である
レベル 2
dp[2]
= 種類である
レベル 3
- レベル 1 が 3 個の場合:
種類
- レベル 1 が 2 個、レベル 2 が 1 個の場合:
dp[2]
種類
レベル 4
- レベル 1 が 4 個の場合:
種類
- レベル 1 が 3 個、レベル 2 が 1 個の場合:
dp[2]
種類 - レベル 1 が 3 個、レベル 3 が 1 個の場合:
dp[3]
種類 - レベル 1 が 2 個、レベル 2 が 1 個、レベル 3 が 1 個の場合:
dp[2]
dp[3]
種類
畳み込みへ
ここまで考えてみて「畳み込みっぽい」式になりそうだとわかる。
- レベル 1 を表す多項式:
- レベル 2 を表す多項式:
dp[1]
- レベル 3 を表す多項式:
dp[2]
- ...
- レベル
を表す多項式:
dp [n-1]
としたとき、
dp[n]
=
と表せることになる。ここまでで 解法にはなった。具体的には、
for (int i = 1; i <= N; ++i) { dp[i] = f[i]; h[i] = {1, dp[i]}; // h[i] = 1 + dp[i]x f *= h[i]; }
というように処理していく。
もし仮に、各 の係数が静的に決まるならば、よくある「二分木のような計算順序」によって、
の計算量となる。
しかし今回は、 の係数が
によって動的に決まるため、難しいな......という気持ちになっていた。以前に分割統治 FFT というテクニックを学んでいて、それが使えないかと考えたけどわからなかった。
平方分割
公式解説はちょっと天才だと思った。でもコンテスト本番中に通されたコードはほとんど平方分割解法だった。それは無理ない解法だと感じたので、そっちをマスターすることにした。
毎回 f *= h[i]
という多項式演算をする代わりに、多項式列 の後ろの方を一時的に計算して貯めておくための多項式
を用意する方法が考えられる。
このとき、dp[i]
の値は、 の
次の項を計算すればよいわけだが、その計算量は
で抑えられるのだ。
回ごとに、
f *= g
をすることにした場合、
- 毎回の
dp[i]
を求める計算量: f *= g
の計算量:回 ×
=
ということで、全体の計算量は となる。
なお、 とするともう少し早くなって、
になる。
コード
#include <bits/stdc++.h> #include <atcoder/modint> #include <atcoder/convolution> using namespace std; using mint = atcoder::modint998244353; // v に (x + a) をかける void update(vector<mint>& v, int& deg, mint a) { ++deg; for (int d = deg; d >= 1; --d) v[d] += v[d-1] * a; } int main() { long long N, A; cin >> N >> A; // f, g vector<mint> f(N + 1, 0); f[0] = 1; for (int i = 1; i <= min(N, A); ++i) f[i] = f[i-1] * (A-i+1) / i; // 貯める計算過程を表す多項式を初期化する関数 int B = 2000; vector<mint> tmp; int deg = 0; auto init = [&]() -> void { tmp.assign(B, 0), tmp[0] = 1, deg = 0; }; // 計算過程を貯めながら計算する mint a = A; init(); for (int n = 2; n <= N; ++n) { a = 0; for (int d = max(0, (int)(n+1-f.size())); d < tmp.size(); ++d) { if (n - d < 0) break; a += tmp[d] * f[n - d]; } update(tmp, deg, a); if (deg == tmp.size() - 1) { f = atcoder::convolution(f, tmp); f.resize(N + 1); init(); } } cout << a.val() << endl; }