けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 040 C - Neither AB nor BA (橙色, 800 点)

これまた楽しい数え上げ!!!
解説があまりにも天才だけど、解説の方法が思いつかなくても一応できた!!!

問題概要

"A", "B", "C" のみからなる長さ  N の文字列であって、以下の条件を満たすものの個数を 998244353 で割ったあまりを求めよ。

  • 以下の操作を繰り返して行うことで空文字列にすることができる
    • 連続する 2 文字であって "AB" と "BA" でないものを削除する

制約

  •  2 \le N \le 10^{7}
  •  N は偶数

解法 (1):僕の考えたこと

例によって、まず判定問題を解く!!!つまり、与えられた文字列  S が空文字にできるかどうかを判定する問題を解く。まず、

  • "AA" が連続している箇所があったら Greedy に消して良い
  • "BB" が連続している箇所があったら Greedy に消して良い

ということが言えそうだと思った。なぜなら、"AA" という箇所があったとき、"AA" の部分で消さないと仮定すると、操作過程で "CAAC" の状態が存在して、"CA" と "AC" をそれぞれ消すしかないが、"C" をこんな形で使用するのはもったいない (厳密な証明もおそらくできる)。さらに、"...ABABCABAB..." とかについては "C" で左右どちらかを消すと連鎖的に消せることなどに着目した。

それらの観察から、作れる文字列が次のように特徴付けられそうだとなった。


  • まず "A" と "B" のみからなる文字列を用意する ("AA" と "BB" をすべて削除すると長さが  a になるとする)
  • その文字列の任意の箇所に "C" を挿入していく。ただし、"C" を挿入した箇所より先の部分については "A" と "B" を反転する ("C" を挿入する回数を  b とする)
  • ただし、 a \le b を満たさなければならない

まず一般に、「"C" を挿入した上でそれよりあとの "A" と "B" を反転する」という操作を行うと、「最終的に消せる "A" や "B" の個数」が 1 だけ増加する (元々空にできる場合は不変)。よって  a \le b を満たせば空文字列にすることができる。

逆に空文字列にできるような文字列は、その操作手順を反転して見方を変えると、上のような操作列によって再現できる。

よって、 C の個数を  k としたとき、

  • "A" と "B" からなる長さ  N-k の文字列であって、"AA" と "BB" を削除して得られる文字列の長さが  k 個以下であるものの個数を  f(k) として
  •  {}_{N}{\rm C}_{k} \times f(k) の総和を求めればよい

ということがわかる。 f(k) を求める問題へと帰着された。

f(k) を求める

"A" と "B" のみからなる長さ  N-k の文字列であって、"AA" と "BB" を削除できる限り削除して得られる文字列の長さ (これを文字列のスコアを呼ぶことにする) が  k 以下であるものの個数を求める問題へと帰着された。

まず愚直には次のような DP でできる。

  •  {\rm dp} \lbrack i \rbrack \lbrack j \rbrack := 長さ  i の文字列であって、スコアが  j であるものの個数

このとき、遷移は次のように書ける。

  •  {\rm dp} \lbrack i \rbrack \lbrack 0 \rbrack =  {\rm dp} \lbrack i - 1 \rbrack \lbrack 1 \rbrack
  •  {\rm dp} \lbrack i \rbrack \lbrack 1 \rbrack =  2{\rm dp} \lbrack i - 1 \rbrack \lbrack 0 \rbrack + {\rm dp} \lbrack i - 1 \rbrack \lbrack 2 \rbrack
  •  {\rm dp} \lbrack i \rbrack \lbrack j \rbrack =  {\rm dp} \lbrack i -1 \rbrack \lbrack j -1 \rbrack + {\rm dp} \lbrack i - 1 \rbrack \lbrack j + 1 \rbrack ( j \ge 2)

このままだと  O(N^{2}) の計算量となる。しかし DP 遷移をよく観察すると、実は

  •  {\rm dp} \lbrack i \rbrack \lbrack 0 \rbrack =  {}_{i}{\rm C}_{i/2} ( i が偶数)、0 ( i が奇数)
  •  {\rm dp} \lbrack i \rbrack \lbrack j \rbrack =  2{}_{i}{\rm C}_{i/2+j} ( i-j が偶数)、0 ( i が奇数)

となっていることがわかった。よって実は次のようになる。

  •  f(0) = {}_{N}{\rm C}_{N/2}
  •  f(1) = {}_{N-1}{\rm C}_{N/2 - 1} + {}_{N-1}{\rm C}_{N/2}
  •  f(2) = {}_{N-2}{\rm C}_{N/2 - 2} + {}_{N-2}{\rm C}_{N/2 - 1} + {}_{N-2}{\rm C}_{N/2}
  • ...

 k \ge N/2 に対しては  f(k) = 2^{N-k} となる。あとは二項係数の連続する箇所の総和が高速に求められれば OK。これは、 f(k) を用いて  f(k+1) を表す方針によって高速にできる。

たとえば  N = 10 のとき、 k = 2 の場合から  k= 3 の場合を求めるのは次のようにできる。

  •  k = 2 の場合は  {}_{8}{\rm C}_{3} + {}_{8}{\rm C}_{4} + {}_{8}{\rm C}_{5} となる
  • これを式変形すると  {}_{7}{\rm C}_{2} + 2{}_{7}{\rm C}_{3} + 2{}_{7}{\rm C}_{4} + {}_{7}{\rm C}_{5} となる
  • よってこれに  {}_{7}{\rm C}_{2} {}_{7}{\rm C}_{5} を足して 2 で割れば、 k = 3 の場合である  {}_{7}{\rm C}_{2} + {}_{7}{\rm C}_{3} + {}_{7}{\rm C}_{4} + {}_{7}{\rm C}_{5} に一致する

以上を総合すると、 f(0), f(1), \dots, f(N) O(N) で計算できることがわかった。これに基づいて  \sum_{k = 0}^{N} {}_{N}{\rm C}_{k} f(k) O(N) で計算できる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

// Binomial coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

int main() {
    int N;
    cin >> N;
    BiCoef<mint> bc(N+10);

    mint res = 0, sum = 0;
    for (int k = 0; k <= N; ++k) {
        sum = (sum + bc.com(N-k, N/2-k) * 2) / 2;
        res += bc.com(N, k) * sum;
    }
    cout << res << endl;
}

解法 (2):想定解法

これ天才すぎる!!!でも確かに僕の解法も本質的には同一なのかもしれない。

  • 奇数番目の "A" と "B" を入れ替える

という変換を行った先では、操作内容は次のように言い換えられる

「"AA" と "BB" は削除できない。それ以外は削除できる。」

とてもわかりやすくなった。そうすると、可能文字列の特徴づけも容易になる。次のようになる。


以下をともに満たす文字列を数え上げよ


僕の二項係数の式も、確かに見方によってはこの条件を満たすものを数え上げるものになっていることがわかる (C の個数を固定して、A の個数として満たすべき範囲で二項係数を足していく)。

しかし、もっと数えやすい数え方がある。

全体の個数 (3^N 通り)
- (A が過半数を占める文字列の個数)
- (B が過半数を占める文字列の個数)

を計算すれば OK。計算量は同じく  O(N) でできる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

// Binomial coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

int main() {
    int N;
    cin >> N;
    BiCoef<mint> bc(N+1);
    vector<mint> two(N+1, 1);
    for (int i = 1; i <= N; ++i) two[i] = two[i-1] * 2;

    mint res = modpow(mint(3), N);
    for (int k = N/2 + 1; k <= N; ++k) {
        res -= bc.com(N, k) * two[N - k] * 2;
    }
    cout << res << endl;
}