けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder Library Practice Contest I - Number of Substrings

Suffix Array と LCP の理解を問う問題。超シンプルで面白い問題!

問題概要

長さが  N の文字列  S が与えられる。

 S の連続する部分文字列の種類数を答えよ。

制約

  •  1 \le N \le 5 \times 10^{5}

解法

文字列  S の Suffix とは「後ろの何文字かをとってできる文字列」のことである。長さ  N の文字列の Suffix は  N+1 種類ある (末尾から末尾という空文字列も含む)。

Suffix Array とは、これら  N+1 個の Suffix を辞書順にソートしたものだ。たとえば、 S = poporinri (9 文字) の場合、Suffix Array は次のようになる。

0: (9 文字目以降、空文字列)
1:i (8 文字目以降)
2:inri (5 文字目以降)
3:nri (6 文字目以降)
4:oporinri (1 文字目以降)
5:orinri (3 文字目以降)
6:poporinri (0 文字目以降)
7:porinri (2 文字目以降)
8:ri (7 文字目以降)
9:rinri (4 文字目以降)

この場合、求められる配列 sa は次のようになる。

sa = [9, 8, 5, 6, 1, 3, 0, 2, 7, 4]

高さ配列 LCP

Suffix Array を求めるメリットの一つは、次の高さ配列 lcp (長さ  N) が求められることだ。


  • lcp[i] ← 文字列  S の suffix を辞書順で小さい順に並べたときの、 i 番目のものと  i+1 番目のものについて、先頭が何文字まで一致しているか

たとえば先ほどの文字列  S = poporinri を例にとる。以下の Suffix Array の隣接箇所について、先頭何文字目まで一致しているかを数えて、次のように求められる。

0: (9 文字目以降、空文字列)
1:i (8 文字目以降)
2:inri (5 文字目以降)
3:nri (6 文字目以降)
4:oporinri (1 文字目以降)
5:orinri (3 文字目以降)
6:poporinri (0 文字目以降)
7:porinri (2 文字目以降)
8:ri (7 文字目以降)
9:rinri (4 文字目以降)

lcp = [0, 1, 0, 0, 1, 0, 2, 0, 2]

となる。

さて、部分文字列の個数を数え上げよう。上記の各 suffix に対して、prefix をとったものの集合の要素数が答えとなる。重複をいかに除去するかが問題だ。

たとえば、sa[6] を表す文字列 "poporinri" と sa[7] を表す文字列 "porinri" は、ともに prefix として "p" と "po" を含んでいて重複してしまう。この重複を除去するのに高さ配列 lcp が役に立つ。

具体的には「Suffix Array の順序で後ろの方の文字列の prefix にも登場する文字列は除外して数えていく」というようにできる。上の例では次表のようになる。この「重複分」の値が lcp[i] に他ならない。

Suffix Array での順序 suffix 文字列 重複分 重複分を除外した個数
0 0 0 - 0 = 0
1 i 1 1 - 1 = 0
2 inri 0 4 - 0 = 4
3 nri 0 3 - 0 = 3
4 oporinri 1 8 - 1 = 7
5 orinri 0 6 - 0 = 6
6 poporinri 2 9 - 2 = 7
7 porinri 0 7 - 0 = 7
8 ri 2 2 - 2 = 0
9 rinri 0 5 - 0 = 5

右端の値の総和が答えとなる。

さらに整理して考えると、重複分を除外する前の部分文字列の個数は  \frac{N(N+1)}{2} となる。重複分の総和は高さ配列 lcp の総和となる。よって求める答えは

 \displaystyle \frac{N(N+1)}{2} - \sum_{i = 0}^{N-1} lcp[i]

と簡潔に表せる。

Suffix Array や高さ配列 LCP は、 O(N) で求められるので、この解法の計算量も  O(N) と評価できる。

コード

Suffix Array や LCP を求めるアルゴリズムは、ACL では、atcoder::string で実装されている。

#include <bits/stdc++.h>
#include <atcoder/string>
using namespace std;

int main() {
    string S;
    cin >> S;
    long long N = S.size();

    // Suffix Array の構築
    vector<int> sa = atcoder::suffix_array(S);
    vector<int> lcp = atcoder::lcp_array(S, sa);

    // 高さ配列の総和を求める
    long long sum = 0;
    for (auto v : lcp) sum += v;

    // 答え
    cout << N*(N+1)/2 - sum << endl;
}

自前ライブラリでも AC

#include <bits/stdc++.h>
using namespace std;

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;   // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> lcp;  // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    int& operator [] (int i) {
        return sa[i];
    }

    // constructor
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        int N = (int)str.size();
        vector<int> s;
        for (int i = 0; i < N; ++i) s.push_back(str[i] + 1);
        s.push_back(0);
        sa = sa_is(s);
        calcLCP(s);
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // prepair lcp
    vector<int> rsa;
    void calcLCP(const vector<int> &s) {
        int N = (int)s.size();
        rsa.assign(N, 0), lcp.assign(N, 0);
        for (int i = 0; i < N; ++i) rsa[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rsa[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rsa[i] - 1] = h;
        }
    }
};

int main() {
    string S;
    cin >> S;
    long long N = S.size();

    // Suffix Array の構築
    SuffixArray<string> sa(S);

    // 高さ配列の総和を求める
    long long sum = 0;
    for (auto v : sa.lcp) sum += v;

    // 答え
    cout << N*(N+1)/2 - sum << endl;
}