けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 110 E - Shorten ABC (赤色, 800 点)

コンテスト中に間に合わなかった。あと、僕の考察は XOR の言葉ではなかったけど、よく考えたら XOR と等価だった。

問題概要

"A", "B", "C" のみからなる長さ  N の文字列  S が与えられる。この文字列に以下の操作を好きな順序で好きな回数だけ行える。操作によって作ることのできる文字列として考えられるものの個数を 1000000007 で割ったあまりを求めよ。

  •  S 中の連続する 2 文字であって、同文字でないものを選ぶ
  • その 2 文字を削除して、代わりに「その 2 文字のいずれでもない文字」を挿入する

(たとえば "AABBCC" -> "ACBCC")

制約

  •  1 \le N \le 10^{6}

考えたこと

まずは例によって「どんな文字列なら作れるのか」を考察することにした。いくつか試して思ったのは

「ある文字列が操作によって最終的に 1 文字になるならば、その文字はマージ順序によらない」

ということだった。より具体的には、"A", "B", "C" の個数をそれぞれ  p, q, r としたとき

  •  p, q, r がすべて偶数のとき、1 文字にはできない (最後が "AA" や "CC" などになる)
  •  p, q, r がすべて奇数のとき、1 文字にはできない (最後が "AA" や "CC" などになる)
  •  p, q, r のうち奇数が 1 個だけの場合は、その文字が生き残る (たとえば "AAAABCC" なら "B" が生き残る)
  •  p, q, r のうち偶数が 1 個だけの場合は、その文字が生き残る (たとえば "AAABBC" なら "B" が生き残る)

ということがわかった。以上の事柄を整理すると、「隣接する 2 文字のマージ」は「演算」として表せるのだろうと考えた。

  • "AA" のように消せないものを表す数値を便宜的に 0
  • "A" を 1
  • "B" を 2
  • "C" を 3

と表すことにする。上記の考察は、この演算が結合法則を満たすということを意味する (実際は、交換法則も満たしているし、逆元も存在する)。具体的には、次のような関数で表現できることがわかった (なぜここまで導いていて、これを XOR だと思わなかったのかは謎)。

    auto calc = [&](int a, int b) -> int {
        if (a == 0) return b;
        if (b == 0) return a;
        if (a == b) return 0;
        return 6 - a - b;
    };

対象を一意に定める操作列を導く

最初は  S を区間ごとに分割して、それぞれの区間の結果が 0 以外 (1, 2, 3) になるようにしたものを数えればよいのかなと思った。しかしそれだとサンプル 1 ですでに詰まってしまう。

| 1 | 211 | 3 |
| 1 | 2 | 113 |

とは同じく "123" という結果になる。よって作戦を変更する。ここで部分列 DP でも見られるような「対象物を一意に定めるような操作列」に対応させるアイディアを考えた。そのような操作列は Greedy 手法によるところが大きい。

たとえば、 S から  T を作れるかどうかを判定する問題は次のように解くことができる。


  •  T の文字を左から見ていって (v とする)、
  •  S の残った文字のうち、 S の左から i 文字分の結合結果が v に一致するような最小の i をとる
  •  S の左から i 文字分を削除する

この作業を  T のすべての文字について実施し終えたあと、

  •  S が空文字列になる
  •  S の残った文字列を結合した結果が 0 になる ("AA" とかが残るようなものになる)

のいずれかの条件を満たすとき、 T が作れる。


この判定条件は、そのまま「対象物を一意に定める操作列」となっているので DP できる。

  • dp[ i ] :=  S の最初の i 文字分から作れる文字列の総数

そして、S の i 文字目から j 文字分の結合結果が v (= 1, 2, 3) となる最小の j をそれぞれ nex[ i ][ v ] とすると

dp[ i + nex[ i ][ v ] ] += dp[ i ]

と遷移できる。答えは、S の i 文字目以降の結合結果が 0 となるような各 i に対する dp[ i ] の総和となる。計算量は、nex[ i ][ v ] を上手に求めておくことで  O(N) となる。

コード

実装上は、nex[ i ][ v ] を求めるのに苦労した。愚直にやると

S = "ABAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA...ABC"

のように「同じ文字が連続する箇所」が厄介で、それを高速に抜けないといけないので  O(N^{2}) の計算量となってしまう。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmin(T& a, T b) { 
    if (a > b) { a = b; return 1; } 
    return 0; 
}

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

mint solve(string S) {
    auto calc = [&](int a, int b) -> int {
        if (a == 0) return b;
        if (b == 0) return a;
        if (a == b) return 0;
        return 6 - a - b;
    };
    int N = S.size();
    vector<int> T(N), ST(N+1, 0), ran(N, N);
    for (int i = 0; i < N; ++i) T[i] = S[i]-'A'+1, ST[i+1] = calc(ST[i], T[i]);
    for (int i = 0; i < N;) {
        int j = i;
        while (j < N && T[j] == T[i]) ++j;
        for (int k = i; k < j; ++k) ran[k] = j;
        i = j;
    }
    if (ran[0] == N) return mint(1);
    vector<vector<int>> nex(N, vector<int>(4, N+1));
    for (int i = 0; i < N; ++i) {
        int cur = T[i];
        for (int j = i+1; j <= N;) {
            chmin(nex[i][cur], j);
            if (j < N && T[j-1] == T[j]) {
                if ((ran[j] - j) & 1) cur = calc(cur, T[j]);
                j = ran[j];
            }
            else cur = calc(cur, T[j++]);
            if (nex[i][1] <= N && nex[i][2] <= N  && nex[i][3] <= N) break;
        }
    }

    vector<mint> dp(N+1, 0);
    dp[0] = 1;
    for (int i = 0; i < N; ++i) {
        for (int j = 1; j <= 3; ++j) if (nex[i][j] <= N) dp[nex[i][j]] += dp[i];
        if (i > 0 && ST[i] == ST.back()) dp[N] += dp[i];
    }
    return dp[N];
}

int main() {
    int N;
    string S;
    while (cin >> N >>  S) cout << solve(S) << endl;
}