けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 141 E - Who Says a Pun? (水色, 500 点)

文字列検索に関するライブラリが充実していれば怖いものがない。でも文字列のことを知らなくても実は DP でも解ける!!!

  • Suffix Array
  • Z-algorithm (editorial 解)
  • ロリハ + 二分探索
  • 「ロリハ + 二分探索」の高速化 (editorial のラスト 3 行で言及された別解)
  • DP

の五通りの方法でやってみる

問題へのリンク

問題概要

長さ  N の文字列  S があたえれる。
 S の連続する部分文字列として、重ならずに 2 回以上現れるもののうち、最長のものの長さを答えてください。

制約

  •  2 \le N \le 5000

解法 1:Suffix Array の LCP 配列

まず、蟻本の P.340 に書いてある方法。
ここで文字列 S の i 文字目から先を取り出した部分文字列を S[ i : ] と書くことにする。 さて、Suffix Array で何ができるのかだけ書くと

  •  O(N (\log{N})^{2}) で Suffix Array を構成する ( O(N) の方法もある)
  • それによって求められた LCP 配列と呼ばれるものを Sparse Table に載せる ( O(N \log{N}) かかる)
  • 以上の前処理を行っておくと、S[ i : ] と S[ j : ] とが、先頭から何文字まで共通しているか (LCP とよぶ) を、 O(1) で判定することができる

ということになる (ここまですべて蟻本参照、ただし蟻本では Sparse Table のところをセグメントツリーでやっている)。この方法を使えば

  • 各 i < j に対して、
  • LCP の値 lcp を求めて
  • そのままだと部分文字列がかぶることもある ("ababa" で i = 0, j = 2 のとき LCP は "aba" で長さ 3) ので、lcp が j - i と比べて大きかったら lcp = j - i とする

というのを各 i < j で全探索して最大値を求めれば OK。計算量は  O(N^{2})

以下のコードで、main 関数はとても短くて「ライブラリがあれば貼るだけ」を強く実感できる!!!

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;


// Sparse Table
template<class MeetSemiLattice> struct SparseTable {
    vector<vector<MeetSemiLattice> > dat;
    vector<int> height;
    
    SparseTable() { }
    SparseTable(const vector<MeetSemiLattice> &vec) { init(vec); }
    void init(const vector<MeetSemiLattice> &vec) {
        int n = (int)vec.size(), h = 0;
        while ((1<<h) < n) ++h;
        dat.assign(h, vector<MeetSemiLattice>(1<<h));
        height.assign(n+1, 0);
        for (int i = 2; i <= n; i++) height[i] = height[i>>1]+1;
        for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
        for (int i = 1; i < h; ++i)
            for (int j = 0; j < n; ++j)
                dat[i][j] = min(dat[i-1][j], dat[i-1][min(j+(1<<(i-1)),n-1)]);
    }
    
    MeetSemiLattice get(int a, int b) {
        return min(dat[height[b-a]][a], dat[height[b-a]][b-(1<<height[b-a])]);
    }
};


// Suffix Array ( Manber&Myers: O(n (logn)^2) )
struct SuffixArray {
    string str;
    vector<int> sa;         // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> lcp;        // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    inline int& operator [] (int i) {return sa[i];}
    
    SuffixArray(const string& str_) : str(str_) { buildSA(); calcLCP(); }
    void init(const string& str_) { str = str_; buildSA(); calcLCP(); }
    
    // build SA
    vector<int> rank_sa, tmp_rank_sa;
    struct CompareSA {
        int n, k;
        const vector<int> &rank;
        CompareSA(int n, int k, const vector<int> &rank_sa) : n(n), k(k), rank(rank_sa) {}
        bool operator()(int i, int j) {
            if (rank[i] != rank[j]) return (rank[i] < rank[j]);
            else {
                int rank_ik = (i + k <= n ? rank[i + k] : -1);
                int rank_jk = (j + k <= n ? rank[j + k] : -1);
                return (rank_ik < rank_jk);
            }
        }
    };
    void buildSA() {
        int n = (int)str.size();
        sa.resize(n+1), lcp.resize(n+1), rank_sa.resize(n+1), tmp_rank_sa.resize(n+1);
        for (int i = 0; i < n; ++i) sa[i] = i, rank_sa[i] = (int)str[i];
        sa[n] = n, rank_sa[n] = -1;
        for (int k = 1; k <= n; k *= 2) {
            CompareSA csa(n, k, rank_sa);
            sort(sa.begin(), sa.end(), csa);
            tmp_rank_sa[sa[0]] = 0;
            for (int i = 1; i <= n; ++i) {
                tmp_rank_sa[sa[i]] = tmp_rank_sa[sa[i - 1]];
                if (csa(sa[i - 1], sa[i])) ++tmp_rank_sa[sa[i]];
            }
            for (int i = 0; i <= n; ++i) rank_sa[i] = tmp_rank_sa[i];
        }
    }
    vector<int> rsa;
    SparseTable<int> st;
    void calcLCP() {
        int n = (int)str.size();
        rsa.resize(n+1);
        for (int i = 0; i <= n; ++i) rsa[sa[i]] = i;
        lcp.resize(n+1);
        lcp[0] = 0;
        int cur = 0;
        for (int i = 0; i < n; ++i) {
            int pi = sa[rsa[i] - 1];
            if (cur > 0) --cur;
            for (; pi + cur < n && i + cur < n; ++cur) {
                if (str[pi + cur] != str[i + cur]) break;
            }
            lcp[rsa[i] - 1] = cur;
        }
        st.init(lcp);
    }
    
    // calc lcp
    int getLCP(int a, int b) {          // lcp of str.sutstr(a) and str.substr(b)
        return st.get(min(rsa[a], rsa[b]), max(rsa[a], rsa[b]));
    }
};

int main() {
    int N;
    string S;
    cin >> N >> S;

    // Suffix Array 構築
    SuffixArray SA(S);

    // 全探索
    int res = 0;
    for (int i = 0; i < N; ++i) {
        for (int j = i+1; j < N; ++j) {
            int lcp = SA.getLCP(i, j);
            lcp = min(lcp, j-i);
            res = max(res, lcp);
        }
    }
    cout << res << endl;
}

解法 2:Z-algorithm

ほとんど同じ思想の解法で、Z-algorithm は

  •  O(N) の前処理で配列 lcp を求める
  • lcp[ i ] := S 自身と S[ i : ] とが、先頭から最大で何文字一致しているかを  O(1) で求めることができる

というもの。Suffix Array の LCP 配列は LCP(i, j) を求めることができたのに対して、Z-algorithm では LCP(0, i) の形しか求めることができない。しかしその対策は簡単で

  • 各 i に対して  S の i 文字目以降の部分文字列を T とする
  • T について Z-algorithm を適用する

という風にすれば OK。計算量はやはり  O(N^{2})

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
vector<int> Zalgo(const string &S) {
    int N = (int)S.size();
    vector<int> res(N);
    res[0] = N;
    int i = 1, j = 0;
    while (i < N) {
        while (i+j < N && S[j] == S[i+j]) ++j;
        res[i] = j;
        if (j == 0) {++i; continue;}
        int k = 1;
        while (i+k < N && k+res[k] < j) res[i+k] = res[k], ++k;
        i += k, j -= k;
    }
    return res;
}

int main() { 
    int N;
    string S;
    cin >> N >> S;

    int res = 0;
    for (int i = 0; i < N; ++i) {
        string T = S.substr(i);
        auto lcp = Zalgo(T);

        for (int j = 0; j < T.size(); ++j) {
            int l = min(lcp[j], j);
            res = max(res, l);
        }
    }
        
    cout << res << endl;
}

解法 3:ロリハ + 二分探索

実は LCP はローリングハッシュ + 二分探索でも求められるという話は結構有名だったりする。ローリングハッシュは

  •  O(N) の前処理を行っておくことで
  • 文字列 S の区間 [i, j) の部分文字列に関するハッシュ値を  O(1) で返す

というもの。発想は累積和とものすごく近い!!!

さて、これができれば

  • 各 i < j に対して
  • 「i 文字目から始めて x 文字分とったもの」と「j 文字目から始めて x 文字分とったもの」のハッシュ値が一致するような最大の x を求める (by 二分探索)

という方法によって、 O(\log{N}) で LCP(i, j) を求めることができる。計算量は  O(N^{2}\log{N}) となる。このままではかなり定数倍が厳しい。以下のコードは 1735ms だった。

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;


// ローリングハッシュ
// 二分探索で LCP を求める機能つき
struct RollingHash {
    static const int base1 = 1007, base2 = 2009;
    static const int mod1 = 1000000007, mod2 = 1000000009;
    vector<long long> hash1, hash2, power1, power2;

    // construct
    RollingHash(const string &S) {
        int n = (int)S.size();
        hash1.assign(n+1, 0);
        hash2.assign(n+1, 0);
        power1.assign(n+1, 1);
        power2.assign(n+1, 1);
        for (int i = 0; i < n; ++i) {
            hash1[i+1] = (hash1[i] * base1 + S[i]) % mod1;
            hash2[i+1] = (hash2[i] * base2 + S[i]) % mod2;
            power1[i+1] = (power1[i] * base1) % mod1;
            power2[i+1] = (power2[i] * base2) % mod2;
        }
    }
    
    // get hash of S[left:right]
    inline pair<long long, long long> get(int l, int r) const {
        long long res1 = hash1[r] - hash1[l] * power1[r-l] % mod1;
        if (res1 < 0) res1 += mod1;
        long long res2 = hash2[r] - hash2[l] * power2[r-l] % mod2;
        if (res2 < 0) res2 += mod2;
        return {res1, res2};
    }

    // get lcp of S[a:] and T[b:]
    inline int getLCP(int a, int b) const {
        int len = min((int)hash1.size()-a, (int)hash1.size()-b);
        int low = 0, high = len;
        while (high - low > 1) {
            int mid = (low + high) >> 1;
            if (get(a, a+mid) != get(b, b+mid)) high = mid;
            else low = mid;
        }
        return low;
    }
};

int main() { 
    int N;
    string S;
    cin >> N >> S;

    // ロリハ
    RollingHash rh(S);

    // 求める
    int res = 0;
    for (int i = 0; i < N; ++i) {
        for (int j = i+1; j < N; ++j) {
            int lcp = rh.getLCP(i, j);
            lcp = min(lcp, j-i);
            res = max(res, lcp);
        }
    }   
        
    cout << res << endl;
}

解法 4:ロリハにぶたん解の高速化

ロリハ二分探索解は単純にやると定数倍がかなり辛いので、ちょっとした高速化ができる。

  • 各 (i, j) に対して二分探索して LCP を求める

という風にする代わりに

  • 長さ m を予め固定して、S.substr(i, m) == S.substr(j, m) となるような (i, j) が存在するかどうかを判定して二分探索

という感じにする。各 i に対して区間 [i, i + m) のハッシュ値を求めて、それらに重複があるかどうか (ただしそれらの index が m 以上離れている必要がある) を判定するだけになるので、かなり早くなる。  O(N(\log{N})^{2}) 程度。この工夫で、1735ms -> 13 ms になった。

#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <algorithm>
using namespace std;

// ローリングハッシュ
struct RollingHash {
    static const int base1 = 1007, base2 = 2009;
    static const int mod1 = 1000000007, mod2 = 1000000009;
    vector<long long> hash1, hash2, power1, power2;

    // construct
    RollingHash(const string &S) {
        int n = (int)S.size();
        hash1.assign(n+1, 0);
        hash2.assign(n+1, 0);
        power1.assign(n+1, 1);
        power2.assign(n+1, 1);
        for (int i = 0; i < n; ++i) {
            hash1[i+1] = (hash1[i] * base1 + S[i]) % mod1;
            hash2[i+1] = (hash2[i] * base2 + S[i]) % mod2;
            power1[i+1] = (power1[i] * base1) % mod1;
            power2[i+1] = (power2[i] * base2) % mod2;
        }
    }
    
    // get hash of S[left:right]
    inline pair<long long, long long> get(int l, int r) const {
        long long res1 = hash1[r] - hash1[l] * power1[r-l] % mod1;
        if (res1 < 0) res1 += mod1;
        long long res2 = hash2[r] - hash2[l] * power2[r-l] % mod2;
        if (res2 < 0) res2 += mod2;
        return {res1, res2};
    }
};

int main() {
    int N;
    string S;
    cin >> N >> S;

    // ロリハ
    RollingHash rh(S);

    // 二分探索判定
    auto check = [&](int d) -> bool {
        map<pair<long long, long long>, int> ma;
        for (int i = 0; i + d <= N; ++i) {
            auto p = rh.get(i, i+d);
            if (ma.count(p)) {
                if (i - ma[p] >= d) return true;
            }
            else ma[p] = i;
        }
        return false;
    };

    // 二分探索
    int left = 0, right = N/2 + 1;
    while (right - left > 1) {
        int mid = (left + right) >> 1;
        if (check(mid)) left = mid;
        else right = mid;
    }
        
    cout << left << endl;
}

解法 5:DP

最後に、実は高度な文字列検索アルゴリズムを持っていなくても解けると。

  • dp[ i ][ j ] := i 文字目からと j 文字目からとで最長の長さ (この時点では文字列がかぶっても OK とする)

とすると

  • S[ i ] != S[ j ] なら dp[ i ][ j ] = 0
  • そうでないなら dp[ i ][ j ] = dp[ i + 1 ][ j + 1 ] + 1

という感じ。更新順に注意する。最後は min(dp[ i ][ j ], j - i) の最大値を出力する。

#include <iostream>
#include <vector>
#include <string>
using namespace std;
void chmax(int &a, int b) { if (a < b) a = b; }

int main() {
    int N; string S; cin >> N >> S;

    int res = 0;
    vector<vector<int> > dp(N+1, vector<int>(N+1, 0));
    for (int i = N-1; i >= 0; --i) {
        for (int j = N-1; j > i; --j) {
            if (S[i] == S[j]) chmax(dp[i][j], dp[i+1][j+1] + 1);
            chmax(res, min(dp[i][j], j-i));
        }
    }
    cout << res << endl;
}