こういう問題めっちゃ好き!!!
問題概要
'0' と '1' のみからなる長さ の文字列が与えられる。以下の操作を 回以上 回以下まで行うことができる。
- i < j であって S[ i ] = '0'、S[ j ] = '1' であるような (i, j) を選ぶ
- S[ j ] の '1' を削除して、S[ i ] の '0' の直前に挿入する
操作によって作りうる文字列として考えられるものの個数を 998244353 で割ったあまりを求めよ。
制約
考えたこと
こういう「 回以下の操作で作れるものは何通りか?」という問題では、
- 結局どういう 0-1 列を作ることができるのか
- 文字列 が与えられたときに、 を にするための最小手数はどうやったら求められるのか
- 仮に が無限だったらどういうのが作れるのか
といったことを検討するのが良いと思われる。
操作の言い換え
今回の文字列に対して、以下のように「各 1 の前に何個の 0 があるか」で変換する。ただし「先頭から連続する 1」と「末尾から連続する 0」は一切影響しないので除外することにした。
1111010001100100 ↓ (1, 4, 4, 6)
このような数列に変換したとき、操作は次のように言い換えられる。
- 数列の任意の値を減少させる
- 数列を小さい順に並び替える
元の文字列と数列 (小さい順ソート) とは一対一対応なので、この変換後の数列に対する操作について考えても解ける。
このとき、「元の数列 を数列 にするまでの最小回数」は次のように求められることになる (数列の長さを とする)。
- もし となる が存在したら不可能
- つまり、操作によって絶対に作れない
- と とで一致する箇所の個数を としたとき、必要最小回数は 回となる
- つまり、 ならば作れる
以上より、初期状態の数列を として、
- と とで一致している箇所が 箇所以上
という条件を満たすような数列 の個数を数えれば OK。
DP へ
以上を踏まえて、次のような DP をすることにした。
- := 数列 の最初の 個の値を決める方法のうち、すべての要素の値が 以下であって、 とすでに一致している箇所が 箇所であるようなものの個数
この DP の遷移を愚直に実装すると、次のような遷移になって の計算量になる。 を となる の個数とする。
- +=
- += ( のみ)
- += ( のみ)
- ...
- += ( のみ)
- += ( のみ)
- ...
累積和やいもす法を用いることで、 となった。
コード
#include <bits/stdc++.h> using namespace std; // modint template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; } constexpr int getmod() const { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept { is >> x.val; x.val %= MOD; if (x.val < 0) x.val += MOD; return is; } friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept { if (n == 0) return 1; if (n < 0) return modpow(modinv(r), -n); auto t = modpow(r, n / 2); t = t * t; if (n & 1) t = t * r; return t; } friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } return Fp<MOD>(u); } }; const int MOD = 998244353; using mint = Fp<MOD>; mint solve(const string &S, int K) { vector<int> A; int num = 0; for (auto c : S) { if (c == '0') ++num; else if (num) A.push_back(num); } if (A.empty()) return 1; int N = (int)A.size(), V = A.back(); map<int,int> ma; for (auto v : A) ma[v]++; vector<vector<mint>> dp(N+1, vector<mint>(N+1, 0)), ndp, ndp1, ndp2; for (int i = 0; i <= N; ++i) dp[i][min(i, ma[0])] = 1; for (int v = 0; v < V; ++v) { ndp.assign(N+1, vector<mint>(N+1, 0)); ndp1.assign(N+1, vector<mint>(N+1, 0)); ndp2.assign(N+1, vector<mint>(N+1, 0)); int a = ma[v+1]; for (int i = 0; i <= N; ++i) { if (A[i] <= v) continue; for (int j = 0; j <= i; ++j) { if (dp[i][j] == 0) continue; if (i+1 <= N) ndp1[i+1][j+1] += dp[i][j]; if (i+a+1 <= N) { ndp1[i+a+1][j+a+1] -= dp[i][j]; ndp2[i+a+1][j+a] += dp[i][j]; } } } for (int i = 0; i <= N; ++i) { for (int j = 0; j <= N; ++j) { if (i-1 >= 0 && j-1 >= 0) ndp1[i][j] += ndp1[i-1][j-1]; if (i-1 >= 0) ndp2[i][j] += ndp2[i-1][j]; ndp[i][j] = dp[i][j] + ndp1[i][j] + ndp2[i][j]; } } swap(dp, ndp); } mint res = 0; for (int j = max(N-K, 0); j <= N; ++j) res += dp[N][j]; return res; } int main() { string S; int K; cin >> S >> K; cout << solve(S, K) << endl; }