けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

diverta 2019 C - AB Substrings (400 点)

ペアリングを場合分けしてルールベースで頑張る系の問題、過去に何度もやらかしていて苦手意識が強い

問題へのリンク

問題概要

 N 個の文字列  S_1, S_2, \dots, S_N が与えられる。

これらを任意の順序で連結してできる  N! 通りの文字列のうち、その中に "AB" を連続部分列として含んでいる箇所の個数の最大値を求めよ。

制約

  •  1 \le N \le 10^{4}
  •  2 \le |S_i| \le 10

考えたこと

まず、各  S_i の中に含まれる "AB" については、別途カウントしてしまって大丈夫。その後は文字列 S = "...A" と T = "B..." を連結して

  • S + T = "...AB..."

となることで増える "AB" の個数を最大化する問題になると言える。

整理

状況としては

  • "B....A" な文字列が  x
  • "...A" な文字列が  y 個 (初めの文字が B でない)
  • "B..." な文字列が  z 個 (最後の文字が A でない)

あるときに最大で "AB" を何個増やせるかを考える問題になる。丁寧に考える

y = z = 0 のとき

"(B...A)(B...A)(B... ... ... A)(B...A)" を作ることになって、連結部分は  x-1 個になり、これが最大値である。。。

が、少し罠があって  x = 0 のときは  0 通りになる。よって  \max(x-1, 0) 通り (この罠で WA を生やした)

y = z > 0 のとき

とりあえず、

  • "B...A" を  x 個連結して、その両端を "...A" と "B..." を 1 個ずつ消費して  x+1 個の "AB" を作る
  • 残った "...A" と "B..." で、 \min(y, z) - 1 個作る

ということで合計で、 x + \min(y, z) 個作ることができる。そしてよく考えると、

  • 'A' は  x + y
  • 'B' は  x + z

しかないので、どんなに頑張っても  x + \min(y, z) 個までしか作ることができない。よって最大値は  x + \min(y, z) 個で確定する

y > z のとき

  • "B...A" を  x 個連結して、左端を "...A" を 1 個消費して埋めることで  x
  • "...A" が y-1 個、"B..." が z 個から  x + z 個作ることができる。

そして  x + z = x + \min(y, z) であることに注意すると、上と同様にこれが最大で確定する。

y < z のとき

同様に  x + \min(y, z) が最大で確定する。

まとめ

  •  y = z = 0 のとき、 \max(0, x-1)
  • それ以外のとき、 x + \min(y, z)
#include <iostream>
#include <vector>
#include <string>
using namespace std;

int count(const string &S) {
    int res = 0;
    for (int i = 0; i+1 < S.size(); ++i) {
        if (S[i] == 'A' && S[i+1] == 'B') ++res;
    }
    return res;
}

int N;
vector<string> S;

long long solve() {
    long long res = 0;
    for (int i = 0; i < N; ++i) res += count(S[i]);

    int a = 0, b = 0, c = 0;
    for (int i = 0; i < N; ++i) {
        if (S[i][0] == 'B' && S[i].back() == 'A') ++a;
        else if (S[i].back() == 'A') ++b;
        else if (S[i][0] == 'B') ++c;
    }
    long long add = 0;
    if (b + c == 0) add = max(0, a-1);
    else add = a + min(b, c);

    res += add;
    return res;
}

int main() {
    while (cin >> N) {
        S.resize(N);
        for (int i = 0; i < N; ++i) cin >> S[i];
        cout << solve() << endl;
    }
}