けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 027 E - ABBreviate (1300 点)

問題へのリンク

問題概要 (AGC 027 E)

'a' と 'b' からなる文字列 S が与えられる。

  • 連続する "aa" を "b" に置き換える
  • 連続する "bb" を "a" に置き換える

という操作を好きな回数だけしてできる文字列 T として考えられるものを数え上げよ。

制約

  •  1 \le |S| \le 10^{5}

考えたこと

似た問題は ARC 094 F - Normalization で見た気がする。

"a" を 1、"b" を 2 に対応させると、問題の操作は、「数値の総和を 3 で割った余り」について不変である。

ここで元の文字列を区間分割して、各区間ごとに、T の各文字ごとに対応するようにしてみる。最終的には「異なる区間の分け方が同じ T になる可能性」についてきちんと考えないといけないが、一旦無視して、

  • どのような区間が、"a" や "b" になるのか

をまずはきちんと考察する。少し考えると

  • 3 で割り切れるときはダメ
  • ababababa... みたいに同じ文字が連続している箇所がなかったらダメ
  • そうでないときは、
    • 3 で割ったあまりが 1 のときは "a" につぶせる ("b" には絶対にならない)
    • 3 で割ったあまりが 0 のときは "b" につぶせる ("a" には絶対にならない)

という風になることがわかる。つまり、文字列 T にすることが可能かどうかは、

  • S を区間分割して、各区間ごとに T の対応した文字を作れるようにできるか

という方法で可能なことがわかる。そのような T を数え上げる。ここで問題になるのは、異なる区間分割方法が同じ T を導くことがあること。でもこういうシチュエーションは


文字列 S が与えられて、その部分文字列としてありうるものを数え上げよ


という問題でも出現する。例えば "abcbd" の部分文字列として "abd" があるが、真ん中の "b" の選び方は 2 通りある。こういう問題の対処法として、

  • 異なる選び方で同じ文字列が作られるのなら、最左の選び方を採用する

というのがある:


  • dp[ i ] := i 文字目を最後に選んだ場合の場合の数
  • next[ i ][ c ] := i + 1 文字目以降で最初に文字 c が登場する index

として、各 c に対して

dp[ next[ i ][ c ] ] += dp[ i ]

とする


という感じ。これの真似をすれば OK!

#include <iostream>
#include <string>
#include <vector>
using namespace std;

const int MOD = 1000000007;
string s;

void add(long long &a, long long b) {
    a += b;
    if (a >= MOD) a -= MOD;
}

int main() {
    cin >> s;
    int n = (int)s.size();
    
    // 累積和
    vector<int> sum(n+1, 0);
    for (int i = 0; i < n; ++i) {
        int par = (s[i] == 'a' ? 1 : 2);
        sum[i+1] = (sum[i] + par) % 3;
    }
    
    // 次の累積和が「0」「1」「2」になる瞬間の切れ目
    vector<vector<int> > next(n+1, vector<int>(3, n+1));
    for (int i = n-1; i >= 0; --i) {
        next[i][sum[i+1]] = i+1;
        for (int j = 0; j < 3; ++j) next[i][j] = min(next[i][j], next[i+1][j]);
    }
    
    // 次の「同じ文字 2 連続」を含む切れ目
    vector<int> next2(n+1, n);
    for (int i = n-2; i >= 0; --i) {
        if (s[i] == s[i+1]) next2[i] = i+1;
        else next2[i] = min(next2[i], next2[i+1]);
    }
    
    // DP
    vector<long long> dp(n+1, 0);
    dp[0] = 1;
    for (int i = 0; i < n; ++i) {
        add(dp[i+1], dp[i]);
        int ne = next[next2[i]][(sum[i]+(s[i]=='a'?2:1))%3];
        if (ne <= n) add(dp[ne], dp[i]);
    }
    
    // res
    long long res = 0;
    if (next2[0] == n) res = 1;
    else {
        for (int i = 1; i <= n; ++i) {
            if (sum[i] == sum[n]) add(res, dp[i]);
        }
    }
    
    cout << res << endl;
}