けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 324 E - Joint Two Strings (緑色, 500 点)

問題を典型パーツに分解して考察を積み上げていく系の、とても教育的な問題!

問題概要

 N 個の文字列  S_{1}, S_{2}, \dots, S_{N} と文字列  T が与えられる。以下の条件を満たす  (i, j) の組の個数を求めよ。

  •  S_{i} S_{j} をこの順に連結してできる文字列から、いくつかの文字を選び、順序を保って連結すると文字列  T に一致する

制約

  •  N \le 5 \times 10^{5}
  •  S_{i} の長さの総和は  10^{6} 以下

考えたこと

この手の問題では、考察を一歩一歩積み上げることが肝要といえる。まず、次の問題が解ける必要があることに気づく


2 つの文字列  S, T が与えられる。 S からいくつかの文字を選び、順序を保って連結すると文字列  T に一致させられるかを判定せよ


実はこの問題にはよく知られた Greedy 解法がある。

 T の「次に来る文字」を表す window を持っておき、 S を前から順にみていって、window の文字に一致したら、window を  T の次の文字にセットしなおすようなことをすればよい。そして、 S が最後まで行き着く前に、window が  T の最後の文字までクリアすれば "Yes"、クリアできなければ "No" だ。

こんな感じに実装できる。

bool ok = false;
int iter = 0;  // T の待ち文字を表す window
for (int i = 0; i < S.size(); ++i) {
    if (S[i] == T[iter]) ++iter;

    // T のラス文字までクリアしたら
    if (iter == T.size()) {
        ok = true;
        break;
    }
}

今回の問題へ

同様に、2 つの文字列  S_{i} + S_{j} の部分列で  T を作れるかどうかは次のように判定できる。

  •  S_{i} T を前から見ていったときに、 T の前から  l_{i} 文字分までで一致させられるとする
  •  S_{j} T を後ろから見ていったときに、 T の後ろから  r_{j} 文字分まで一致させられるとする
  • このとき、 l_{i} + r_{j} T の長さ ( M とする) 以上であれば、Yes と判定できる

これを踏まえると、もとの問題は次のように言い換えられる。


2 つのサイズ  N の数列  l_{1}, l_{2}, \dots, l_{N} r_{1}, r_{2}, \dots, r_{N} が与えられる。

 l_{i} + r_{j} \ge M であるような組  (i, j) の個数を求めよ。


ここまで来れば、いつもの二分探索問題となる。 r と小さい順にソートしておくことで、各  i = 1, 2, \dots, N に対して、 r_{j} \ge M - l_{i} となる  j の個数は lower_bound() で求められる。

計算量は  O(\sum_{i=1}^{N}|S_{i}| + N \log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    string T;
    cin >> N >> T;
    
    vector<int> left(N, 0), right(N, 0);
    for (int i = 0; i < N; ++i) {
        string S;
        cin >> S;
        for (int j = 0; j < S.size(); ++j) {
            if (left[i] < T.size() && S[j] == T[left[i]]) {
                ++left[i];
            }
            if (right[i] < T.size() && S[S.size()-1-j] == T[T.size()-1-right[i]]) {
                ++right[i];
            }
        }
    }
    
    int M = T.size();
        
    // left[i] + right[j] >= M となる (i, j) の個数
    long long res = 0;
    sort(right.begin(), right.end());
    for (int i = 0; i < N; ++i) {
        auto it = lower_bound(right.begin(), right.end(), M-left[i]) - right.begin();
        long long tmp = N - it;
        res += tmp;
    }
    cout << res << endl;
}