けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 046 C - Shift (黄色, 800 点)

こういう問題めっちゃ好き!!!

問題概要

'0' と '1' のみからなる長さ  N の文字列が与えられる。以下の操作を  0 回以上  K 回以下まで行うことができる。

  • i < j であって S[ i ] = '0'、S[ j ] = '1' であるような (i, j) を選ぶ
  • S[ j ] の '1' を削除して、S[ i ] の '0' の直前に挿入する

操作によって作りうる文字列として考えられるものの個数を 998244353 で割ったあまりを求めよ。

制約

  •  1 \le |S| \le 300
  •  0 \le K \le 10^{9}

考えたこと

こういう「 K 回以下の操作で作れるものは何通りか?」という問題では、

  • 結局どういう 0-1 列を作ることができるのか
  • 文字列  T が与えられたときに、 S T にするための最小手数はどうやったら求められるのか
  • 仮に  K が無限だったらどういうのが作れるのか

といったことを検討するのが良いと思われる。

操作の言い換え

今回の文字列に対して、以下のように「各 1 の前に何個の 0 があるか」で変換する。ただし「先頭から連続する 1」と「末尾から連続する 0」は一切影響しないので除外することにした。

1111010001100100
↓
(1, 4, 4, 6)

このような数列に変換したとき、操作は次のように言い換えられる。


  • 数列の任意の値を減少させる
  • 数列を小さい順に並び替える

元の文字列と数列 (小さい順ソート) とは一対一対応なので、この変換後の数列に対する操作について考えても解ける。

このとき、「元の数列  A を数列  B にするまでの最小回数」は次のように求められることになる (数列の長さを  N とする)。

  • もし  A_{i} \lt B_{i} となる  i が存在したら不可能
    • つまり、操作によって絶対に作れない
  •  A B とで一致する箇所の個数を  k としたとき、必要最小回数は  N - k 回となる
    • つまり、 K \ge N - k ならば作れる

以上より、初期状態の数列を  A として、

  •  B_{i} \le A_{i}
  •  0 \le B_{0} \le B_{1} \dots \le B_{N-1}
  •  A B とで一致している箇所が  \min(N - K, 0) 箇所以上

という条件を満たすような数列  B の個数を数えれば OK。

DP へ

以上を踏まえて、次のような DP をすることにした。

  •  {\rm dp}\lbrack v \rbrack \lbrack i \rbrack \lbrack j \rbrack := 数列  B の最初の  i 個の値を決める方法のうち、すべての要素の値が  v 以下であって、 A とすでに一致している箇所が  j 箇所であるようなものの個数

この DP の遷移を愚直に実装すると、次のような遷移になって  O(|S|^{4}) の計算量になる。 a A_{i} = v+1 となる  i の個数とする。

  •  {\rm dp}\lbrack v+1 \rbrack \lbrack i \rbrack \lbrack j \rbrack +=  {\rm dp}\lbrack v \rbrack \lbrack i \rbrack \lbrack j \rbrack
  •  {\rm dp}\lbrack v+1 \rbrack \lbrack i + 1 \rbrack \lbrack j + 1 \rbrack +=  {\rm dp}\lbrack v \rbrack \lbrack i \rbrack \lbrack j \rbrack ( A_{i} \ge v+1 のみ)
  •  {\rm dp}\lbrack v+1 \rbrack \lbrack i + 2 \rbrack \lbrack j + 2 \rbrack +=  {\rm dp}\lbrack v \rbrack \lbrack i \rbrack \lbrack j \rbrack ( A_{i} \ge v+1 のみ)
  • ...
  •  {\rm dp}\lbrack v+1 \rbrack \lbrack i + a \rbrack \lbrack j + a \rbrack +=  {\rm dp}\lbrack v \rbrack \lbrack i \rbrack \lbrack j \rbrack ( A_{i} \ge v+1 のみ)
  •  {\rm dp}\lbrack v+1 \rbrack \lbrack i + a + 1 \rbrack \lbrack j + a \rbrack +=  {\rm dp}\lbrack v \rbrack \lbrack i \rbrack \lbrack j \rbrack ( A_{i} \ge v+1 のみ)
  • ...

累積和やいもす法を用いることで、 O(|S|^{3}) となった。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

mint solve(const string &S, int K) {
    vector<int> A;
    int num = 0;
    for (auto c : S) {
        if (c == '0') ++num;
        else if (num) A.push_back(num);
    }
    if (A.empty()) return 1;

    int N = (int)A.size(), V = A.back();
    map<int,int> ma;
    for (auto v : A) ma[v]++;
    vector<vector<mint>> dp(N+1, vector<mint>(N+1, 0)), ndp, ndp1, ndp2;
    for (int i = 0; i <= N; ++i) dp[i][min(i, ma[0])] = 1;
    for (int v = 0; v < V; ++v) {
        ndp.assign(N+1, vector<mint>(N+1, 0));
        ndp1.assign(N+1, vector<mint>(N+1, 0));
        ndp2.assign(N+1, vector<mint>(N+1, 0));
        int a = ma[v+1];
        for (int i = 0; i <= N; ++i) {
            if (A[i] <= v) continue;
            for (int j = 0; j <= i; ++j) {
                if (dp[i][j] == 0) continue;
                if (i+1 <= N) ndp1[i+1][j+1] += dp[i][j];
                if (i+a+1 <= N) {
                    ndp1[i+a+1][j+a+1] -= dp[i][j];
                    ndp2[i+a+1][j+a] += dp[i][j];
                }
            }
        }
        for (int i = 0; i <= N; ++i) {
            for (int j = 0; j <= N; ++j) {
                if (i-1 >= 0 && j-1 >= 0) ndp1[i][j] += ndp1[i-1][j-1];
                if (i-1 >= 0) ndp2[i][j] += ndp2[i-1][j];
                ndp[i][j] = dp[i][j] + ndp1[i][j] + ndp2[i][j];
            }
        }
        swap(dp, ndp);
    }
    mint res = 0;
    for (int j = max(N-K, 0); j <= N; ++j) res += dp[N][j];
    return res;
}

int main() {
    string S;
    int K;
    cin >> S >> K;
    cout << solve(S, K) << endl;
}