けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 045 C - Range Set (橙色, 800 点)

面白かった!!

問題概要

すぬけくんは長さ  N の文字列  x を持っている。最初、 x のすべての文字は 0 である。

すぬけくんは,以下の 2 種類の操作を好きな順序で好きな回数行うことができます.

  •  x の連続する  A 文字を選んで,それらをすべて 0 にする.
  •  x の連続する  B 文字を選んで,それらをすべて 1 にする.

すぬけくんが操作を終えたあとの  x としてありうるものの個数を 1000000007 で割ったあまりを求めよ。

制約

  •  1 \le A, B \le N \le 5000

考えたこと

またしても「最終的に出来上がるものを特徴付ける」というタイプの問題!!!最近の AGC-C でめちゃくちゃ多いパターン!!!

そして今回の問題は

  • 区間に対する操作
  • 上書き操作

であるという特徴がある。区間に対する操作は、ものによっては差分をとって扱ったり平面操作的に処理したりが有効だったりする。しかし今回の問題は、その方向性は意味なさそう。

上書き操作とは、過去の履歴に依らないということ。そういう操作は「後ろから見る」というのが有効に思う。後ろから見ると、操作は次のように言い換えられる。


  • 初期状態では、空の  N マスが与えられる
  • スタンプ (0 と 1) を押したとき
    • すでに数値が押されたマスの数値は変わらない
    • 数値が押されていないマスの数値は、押したスタンプの数値に書き変わる

つまり操作列を逆から見ると「変更不可」の操作に早変わりするのだ。この発想をする問題としては、以下のやつとか (むしろ今回の問題は、これらの問題の上位互換と言える)。

drken1215.hatenablog.com

drken1215.hatenablog.com

さて、改めて必要条件を考えていこう。まず  A \le B と考えて一般性を失わない。このとき、

  • 長さ  B 以上の「数値 0 区間」を含むものは作れる
  • 長さ  B 以上の「数値 1 区間」を含むものは作れる

ということは言える。長さが  B 以上の同数値区間を作るのは容易で、さらにそこから左右に 1 マスずつ好きな数値を埋められるからだ。

これ以降、長さ  B 以上の同一数値区間が存在しないと仮定して考えてみる。そうすると、とりあえず 1 は存在しなければならない (全部 0 はダメ)。よって、スタンプ 1 が押される瞬間が存在する (このとき、すでに 0 になっている箇所からはみ出すことがある)。

そして、一度スタンプ 1 を押してしまえば、その左右はどんなものでも作れる。よって、まとめると、以下のものが作れることになる。


  • 長さ  B 以上の「数値 0 区間」を含むもの
  • 長さ  B 以上の「数値 1 区間」を含むもの
  • 長さ  A 以上の「数値 0 区間」をすべて 1 で埋めた場合に、長さ  B 以上の「数値 1 区間」が形成されるようなもの

あとはこれが数え上げられれば OK!!!

数え上げパート

余事象を数える方が楽なので、余事象を考える。余事象は以下の条件をすべて満たすものとなる。


  • 長さが  B 以上の同一数値区間は含まない
  • 長さが  A 以上の「数値 0 区間」をすべて 1 で埋めたとしても、長さ  B 以上「数値 1 区間」が形成されることはない

これは次のような DP でできそうだ。

  • dp0[ i ] := 条件を満たす長さ i の文字列であって、最後の値が 0 であるようなものの個数
  • dp1[ i ] := 条件を満たす長さ i の文字列であって、最後の値が 1 であるようなものの個数

このとき次の遷移になる。

dp0[ i + j ] += dp1[ i ] (j < A)
dp1[ i + j ] += dp0[ i ] × (f0[ j ] + f1[ j ]) (j < B で i = 0 または i + j = N)
dp1[ i + j ] += dp0[ i ] × f1[ j ] (j < B で i != 0 かつ i + j != N) 

ここで、

  • f0[ i ] := 左端が 1 で右端が 0 であるような長さ i の文字列であって、その内部に含まれる数値 0 の区間の長さがすべて  A 以上  B であるようなものの個数
  • f1[ i ] := 両端が 1 であるような長さ i の文字列であって、その内部に含まれる数値 0 の区間の長さがすべて  A 以上  B であるようなものの個数

としている。最後は f0[ j ], f1[ j ] を求める問題へと帰着された。これは単純な DP であらかじめ求められる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

pair<vector<mint>, vector<mint>> pre(int A, int B) {
    vector<mint> dp0(B, 0), dp1(B, 0);
    dp0[0] = 1;
    for (int i = 0; i < B; ++i) {
        for (int j = 1; i+j < B; ++j) dp1[i+j] += dp0[i];
        for (int j = A; i+j < B; ++j) dp0[i+j] += dp1[i];
    }
    return {dp0, dp1};
}

mint solve(int N, int A, int B) {
    if (A >= B) swap(A, B);
    auto pai = pre(A, B);
    auto f0 = pai.first, f1 = pai.second;
    vector<mint> dp0(N+1, 0), dp1(N+1, 0);
    dp0[0] = dp1[0] = 1;
    for (int i = 0; i <= N; ++i) {
        for (int j = 1; j < A && i+j <= N; ++j) dp0[i+j] += dp1[i];
        for (int j = 1; j < B && i+j <= N; ++j) {
            if (i == 0 || i+j == N) dp1[i+j] += dp0[i] * (f0[j] + f1[j]);
            else dp1[i+j] += dp0[i] * f1[j];
        }
    }
    return modpow(mint(2), N) - dp0[N] - dp1[N];
}

int main() {
    int N, A, B;
    cin >> N >> A >> B;
    cout << solve(N, A, B) << endl;
}

数え上げパートの僕自身が最初にやった方法

最初にやった方法はちょっと煩雑だった。次のような DP でやった。ここで、「1 の実質的な連続長」とは、「0 が A 箇所以上連続している箇所」をすべて 1 で置き換えた場合についての、1 の連続している長さを指すものとする。

  • dp0[ i ][ j ] := 長さ i + j の文字列であって、最後の文字が 0 で、最後の「1 の実質的な連続長」が j であるようなものの個数
  • dp1[ i ][ j ] := 長さ i + j の文字列であって、最後の文字が 1 で、最後の「1 の実質的な連続長」が j であるようなものの個数

このような DP を自然に立てると  O(N^{3}) の DP ができる。それを累積和を用いて高速化するなどすると、 O(N^{2}) となる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

mint solve(int N, int A, int B) {
    if (A >= B) swap(A, B);
    vector<vector<mint>> dp0(N+1, vector<mint>(N+1, 0)), dp1 = dp0;
    vector<vector<mint>> sdp0(N+1, vector<mint>(N+2, 0)), sdp1 = sdp0;
    auto sum = [&](int ind, int i, int left, int right) -> mint {
        chmax(left, 0), chmin(right, N+1);
        if (left >= right) return mint(0);
        if (ind == 0) return sdp0[i][right] - sdp0[i][left];
        else return sdp1[i][right] - sdp1[i][left];
    };
    dp0[0][0] = dp1[0][0] = sdp0[0][1] = sdp1[0][1] = 1;
    for (int i = 0; i <= N; ++i) {
        for (int j = 0; j < B; ++j) {
            if (j == 0) {
                for (int k = 0; k < i; ++k) {
                    dp0[i][j] += sum(1, k, i-k-A+1, min(i-k, B));
                }
            }
            else {
                dp0[i][j] += sum(1, i, j-B+1, j-A+1);
                dp1[i][j] += sum(0, i, j-B+1, j);
            }
            sdp0[i][j+1] = sdp0[i][j] + dp0[i][j];
            sdp1[i][j+1] = sdp1[i][j] + dp1[i][j];
        }
    }
    mint res = 0;
    for (int j = 0; j <= B; ++j) res += dp0[N-j][j] + dp1[N-j][j];
    return modpow(mint(2), N) - res;
}

int main() {
    int N, A, B;
    cin >> N >> A >> B;
    cout << solve(N, A, B) << endl;
}