けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 060 F - 最良表現 (赤色, 900 点)

今回は Suffix Array でやってみたけど、ローリングハッシュとか、KMP とか、Z-Algorithm とか、色んな方法があるみたいなので追々やってみたい。 → やってみた (3/14)

問題へのリンク

問題概要

文字列  x がよい文字列であるとは「いかなる文字列  y および 2 以上の整数  k に対しても、  y k 回繰り返した文字列が  x と異なることをいう。

また文字列の列  (f_1, f_2, ..., f_m)

  •  f_1 + f_2 + \dots + f_m =  S ('+' は文字列の concat)
  •  f_i はすべてよい文字列

を満たしているとき、これを文字列  S の「よい表現」であるという。

文字列  S が与えられて、 S のよい表現として考えられるものとして最も項数が短いものを求め、さらにその項数をもつよい表現が何通りあるかを数え上げよ。

制約

  •  1 \le |S| \le 5 × 10^{5}

解法 1: Suffix Array

 n = |S| とする。

  •  S 自体がよい文字列のとき、明らかに (1, 1 通り)
  •  S を構成する文字がすべて同一のとき、明らかに ( n, 1通り)

それ以外の場合を考える。 S 自体がよい文字列でないということは、 S

abcabcabcabc

みたいな、とある文字列 (長さ  p とする) の繰り返しになっているはずである。そこで、これを

a bcabcabcabc

みたいに先頭 1 文字だけ分けると、各々「よい文字列」になっていることが色々考えるとわかる (もしどちらも「よい文字列」でないとすると、aaaaaaa のような全部同じ文字で構成されていなければならないことが示せる)。

よって、最良表現の項数が 2 であることがわかった。あとは個数を求めるために、

  • S = S[0 : i] + S[i : n]

と区切ったときに、どちらかが「よくない文字列」になる場所をすべて求めて引けば良い。

S[0 : i] がよくない文字列であるとき、i の約数 d が存在して、S[0 : i] は S[0 : d] の繰り返しになっている。繰り返しになっているかどうかは、

  • S[0 : ] と S[d : ] の共通部分文字列が i - d 以上 (これは例えば SuffixArray の LCP 配列を用いて求められる)

によって判定できる。S[i : n] についても同様にできる。

計算量は、各 d = 1, 2, ..., n について、d の倍数を 1〜n まで見ていく感じになるので、

  •  n + n/2 + n/3 + ... + n/n ~  O(n \log{n})

になる。

巧妙な枝刈りで O(n) に

 O(n\log{n}) で十分通るのだが、S[0 : i] が良くない場所を列挙しているときに

  • d 箇所目で切れないことがわかっている場合には、d の倍数を見ていくことはしない (既に d の約数 e が存在して、e の倍数の箇所は調べ上げられているため)

とすると、各切り目につき高々 2 回しか check されないので、全体の計算量が  O(n) まで改善される。Suffix Array を SA-IS を用いると  O(n) で構築できるため、SA 構築を含めても  O(n) になる (...が、ここでは普通の SA を用いた)。

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;

// Sparse Table
template<class MeetSemiLattice> struct SparseTable {
    const MeetSemiLattice INF = 1<<30;
    vector<vector<MeetSemiLattice> > dat;
    vector<int> height;
    
    SparseTable() { }
    SparseTable(const vector<int> &vec) { init(vec); }
    void init(const vector<int> &vec) {
        int n = (int)vec.size(), h = 0;
        while ((1<<h) < n) ++h;
        dat.assign(h, vector<MeetSemiLattice>(1<<h));
        height.assign(n+1, 0);
        for (int i = 2; i <= n; i++) height[i] = height[i>>1]+1;
        for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
        for (int i = 1; i < h; ++i)
            for (int j = 0; j < n; ++j)
                dat[i][j] = min(dat[i-1][j], dat[i-1][min(j+(1<<(i-1)),n-1)]);
    }
    
    MeetSemiLattice get(int a, int b) {
        if (a >= b) return INF;
        return min(dat[height[b-a]][a], dat[height[b-a]][b-(1<<height[b-a])]);
    }
};

// Suffix Array ( Manber&Myers: O(n (logn)^2) )
struct SuffixArray {
    string str;
    vector<int> sa;         // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> lcp;        // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    inline int& operator [] (int i) {return sa[i];}
    
    SuffixArray(const string& str_) : str(str_) { buildSA(); calcLCP(); }
    void init(const string& str_) { str = str_; buildSA(); calcLCP(); }
    
    // build SA
    vector<int> rank_sa, tmp_rank_sa;
    struct CompareSA {
        int n, k;
        const vector<int> &rank;
        CompareSA(int n, int k, const vector<int> &rank_sa) : n(n), k(k), rank(rank_sa) {}
        bool operator()(int i, int j) {
            if (rank[i] != rank[j]) return (rank[i] < rank[j]);
            else {
                int rank_ik = (i + k <= n ? rank[i + k] : -1);
                int rank_jk = (j + k <= n ? rank[j + k] : -1);
                return (rank_ik < rank_jk);
            }
        }
    };
    void buildSA() {
        int n = (int)str.size();
        sa.resize(n+1), lcp.resize(n+1), rank_sa.resize(n+1), tmp_rank_sa.resize(n+1);
        for (int i = 0; i < n; ++i) sa[i] = i, rank_sa[i] = (int)str[i];
        sa[n] = n, rank_sa[n] = -1;
        for (int k = 1; k <= n; k *= 2) {
            CompareSA csa(n, k, rank_sa);
            sort(sa.begin(), sa.end(), csa);
            tmp_rank_sa[sa[0]] = 0;
            for (int i = 1; i <= n; ++i) {
                tmp_rank_sa[sa[i]] = tmp_rank_sa[sa[i - 1]];
                if (csa(sa[i - 1], sa[i])) ++tmp_rank_sa[sa[i]];
            }
            for (int i = 0; i <= n; ++i) rank_sa[i] = tmp_rank_sa[i];
        }
    }
    
    // build LCP
    vector<int> rsa;
    SparseTable<int> st;
    void calcLCP() {
        int n = (int)str.size();
        rsa.resize(n+1);
        for (int i = 0; i <= n; ++i) rsa[sa[i]] = i;
        lcp.resize(n+1);
        lcp[0] = 0;
        int cur = 0;
        for (int i = 0; i < n; ++i) {
            int pi = sa[rsa[i] - 1];
            if (cur > 0) --cur;
            for (; pi + cur < n && i + cur < n; ++cur) {
                if (str[pi + cur] != str[i + cur]) break;
            }
            lcp[rsa[i] - 1] = cur;
        }
        st.init(lcp);
    }
    int getLCP(int a, int b) {          // lcp of str.sutstr(a) and str.substr(b)
        return st.get(min(rsa[a], rsa[b]), max(rsa[a], rsa[b]));
    }
};



vector<long long> divisor(long long n) {
    vector<long long> res;
    for (long long i = 1LL; i*i <= n; ++i) {
        if (n%i == 0LL) {
            res.push_back(i);
            long long temp = n/i;
            if (i != temp) res.push_back(temp);
        }
    }
    sort(res.begin(), res.end());
    return res;
}

int main() {
    string str; cin >> str;
    int n = (int)str.size();
    vector<long long> divs = divisor(n);
    long long syuuki = n;
    for (auto d : divs) {
        bool ok = true;
        for (int j = 0; j + d < n; ++j) {
            if (str[j] != str[j+d]) ok = false;
        }
        if (ok) syuuki = min(syuuki, d);
    }
    if (syuuki == n) cout << 1 << endl << 1 << endl;
    else if (syuuki == 1) cout << n << endl << 1 << endl;
    else {
        vector<int> cannot_cut(n*2, 0);
        string str2 = str;
        reverse(str2.begin(), str2.end());
        SuffixArray sa1(str);
        SuffixArray sa2(str2);
        for (int d = 1; d < n; ++d) {
            if (cannot_cut[d]) continue;
            for (int dd = d*2; dd < n; dd += d) {
                if (sa1.getLCP(0, d) >= dd - d) cannot_cut[dd] = true;
                if (sa2.getLCP(0, d) >= dd - d) cannot_cut[n-dd] = true;
            }
        }
        int con = 0;
        for (int i = 1; i < n; ++i) if (!cannot_cut[i]) ++con;
        cout << 2 << endl << con << endl;
    }
}

解法 2: ローリングハッシュ

ローリングハッシュでもできる。ロリハを用いる場合、LCP 計算は通常内部的に二分探索を行うが、それをしなくても直接

            for (int dd = d; dd < n; dd += d) {
                if (rh.get(0, d) != rh.get(dd, dd+d)) break;
                cannot_cut[dd + d] = true;
            }

という感じで、長さ d の区間をとってハッシュ値が一致するかどうかを見ればよい。

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;

struct RollingHash {
    const int base = 9973;
    const vector<int> mod = {999999937LL, 1000000007LL};
    string S_;
    vector<long long> hash[2], power[2];
    RollingHash(){}
    RollingHash(const string &S) : S_(S) {
        int n = (int)S.size();
        for (int iter = 0; iter < 2; ++iter) {
            hash[iter].assign(n+1, 0);
            power[iter].assign(n+1, 1);
            for (int i = 0; i < n; ++i) {
                hash[iter][i+1] = (hash[iter][i] * base + S[i]) % mod[iter];
                power[iter][i+1] = power[iter][i] * base % mod[iter];
            }
        }
    }
    // get hash of S[left:right]
    inline long long get(int l, int r, int id = 0) const {
        long long res = hash[id][r] - hash[id][l] * power[id][r-l] % mod[id];
        if (res < 0) res += mod[id];
        return res;
    }
    // get lcp of S[a:] and S[b:]
    inline int getLCP(int a, int b) const {
        int len = min((int)S_.size()-a, (int)S_.size()-b);
        int low = -1, high = len + 1;
        while (high - low > 1) {
            int mid = (low + high) / 2;
            if (get(a, a+mid, 0) != get(b, b+mid, 0)) high = mid;
            else if (get(a, a+mid, 1) != get(b, b+mid, 1)) high = mid;
            else low = mid;
        }
        return low;
    }
    // get lcp of S[a:] and T[b:]
    inline int getLCP(const RollingHash &t, int a, int b) const {
        int len = min((int)S_.size()-a, (int)S_.size()-b);
        int low = -1, high = len + 1;
        while (high - low > 1) {
            int mid = (low + high) / 2;
            if (get(a, a+mid, 0) != get(b, b+mid, 0)) high = mid;
            else if (get(a, a+mid, 1) != get(b, b+mid, 1)) high = mid;
            else low = mid;
        }
        return low;
    }
};


vector<long long> divisor(long long n) {
    vector<long long> res;
    for (long long i = 1LL; i*i <= n; ++i) {
        if (n%i == 0LL) {
            res.push_back(i);
            long long temp = n/i;
            if (i != temp) res.push_back(temp);
        }
    }
    sort(res.begin(), res.end());
    return res;
}
    
int main() {
    string str; cin >> str;
    int n = (int)str.size();
    vector<long long> divs = divisor(n);
    long long syuuki = n;
    for (auto d : divs) {
        bool ok = true;
        for (int j = 0; j + d < n; ++j) {
            if (str[j] != str[j+d]) ok = false;
        }
        if (ok) syuuki = min(syuuki, d);
    }
    if (syuuki == n) cout << 1 << endl << 1 << endl;
    else if (syuuki == 1) cout << n << endl << 1 << endl;
    else {
        vector<int> cannot_cut(n*2, 0);
        RollingHash rh(str);
        for (int d = 1; d < n; ++d) {
            if (cannot_cut[d]) continue;
            for (int dd = d; dd < n; dd += d) {
                if (rh.get(0, d) != rh.get(dd, dd+d)) break;
                cannot_cut[dd + d] = true;
            }
            for (int dd = n-d*2; dd >= 0; dd -= d) {
                if (rh.get(dd, dd+d) != rh.get(n-d, n)) break;
                cannot_cut[dd] = true;
            }
        }
        int con = 0;
        for (int i = 1; i < n; ++i) if (!cannot_cut[i]) ++con;
        cout << 2 << endl << con << endl;
    }
}

解法 3: Z-algorithm

今回 LCP を求めるときに

  • S[0:] と S[i:] との LCP

の形しか用いていないことから、Z-algorithm も使える。Z-algorithm を用いると、 O(N)

  • res[ i ] := S[0:] と S[i:] との LCP

を求めることができる。

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;


vector<int> Zalgo(const string &S) {
    int N = (int)S.size();
    vector<int> res(N);
    res[0] = N;
    int i = 1, j = 0;
    while (i < N) {
        while (i+j < N && S[j] == S[i+j]) ++j;
        res[i] = j;
        if (j == 0) {++i; continue;}
        int k = 1;
        while (i+k < N && k+res[k] < j) res[i+k] = res[k], ++k;
        i += k, j -= k;
    }
    return res;
}


vector<long long> divisor(long long n) {
    vector<long long> res;
    for (long long i = 1LL; i*i <= n; ++i) {
        if (n%i == 0LL) {
            res.push_back(i);
            long long temp = n/i;
            if (i != temp) res.push_back(temp);
        }
    }
    sort(res.begin(), res.end());
    return res;
}

int main() {
    string str; cin >> str;
    int n = (int)str.size();
    vector<long long> divs = divisor(n);
    long long syuuki = n;
    for (auto d : divs) {
        bool ok = true;
        for (int j = 0; j + d < n; ++j) {
            if (str[j] != str[j+d]) ok = false;
        }
        if (ok) syuuki = min(syuuki, d);
    }
    if (syuuki == n) cout << 1 << endl << 1 << endl;
    else if (syuuki == 1) cout << n << endl << 1 << endl;
    else {
        string str2 = str;
        reverse(str2.begin(), str2.end());
        auto lcp = Zalgo(str);
        auto lcp2 = Zalgo(str2);
        vector<int> cannot_cut(n*2, 0);
        for (int d = 1; d < n; ++d) {
            if (cannot_cut[d]) continue;
            for (int dd = d*2; dd < n; dd += d) {
                if (lcp[d] >= dd - d) cannot_cut[dd] = true;
                if (lcp2[d] >= dd - d) cannot_cut[n-dd] = true;
            }
        }
        int con = 0;
        for (int i = 1; i < n; ++i) if (!cannot_cut[i]) ++con;
        cout << 2 << endl << con << endl;
    }
}

解法 4: KMP 法

文字列検索とか、文字列周期とか、KMP 法が威力を発揮するイメージもある。特に文字列の周期を扱う問題で超強いイメージ。

KMP 法で直接求められるのは、

  • kmp[ i ] := S[0 : i] の suffix と、S の prefix が最大何文字まで一致するか (i 未満で)

である。これを  O(M) で求めることができる。これを利用すると

  • S[0 : i] の最小周期が i - kmp[ i ] で求められる

という特徴がある。ただし、S = "abcabcab" のときに、これは "abc" の繰り返しを途中で区切ることで作れるので周期は 3 となる。

これを利用して、

  • S[0 : i] や S[i : ] について、最小周期がその文字列の長さ以下で、かつ最小周期がその文字列の長さの約数になっているかどうか

を判定できるので、この問題を解くことができる。

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;


struct KMP {
    string pat;
    vector<int> fail;

    // construct
    KMP(const string &p) { init(p); }
    void init(const string &p) {
        pat = p;
        int m = (int)pat.size();
        fail.assign(m+1, -1);
        for (int i = 0, j = -1; i < m; ++i) {
            while (j >= 0 && pat[i] != pat[j]) j = fail[j];
            fail[i+1] = ++j;
        }
    }

    // the period of S[0:i]
    int period(int i) { return i - fail[i]; }
    
    // the index i such that S[i:] has the exact prefix p
    vector<int> match(const string &S) {
        int n = (int)S.size(), m = (int)pat.size();
        vector<int> res;
        for (int i = 0, k = 0; i < n; ++i) {
            while (k >= 0 && S[i] != pat[k]) k = fail[k];
            ++k;
            if (k == m) res.push_back(i-m+1);
        }
        return res;
    }
};


vector<long long> divisor(long long n) {
    vector<long long> res;
    for (long long i = 1LL; i*i <= n; ++i) {
        if (n%i == 0LL) {
            res.push_back(i);
            long long temp = n/i;
            if (i != temp) res.push_back(temp);
        }
    }
    sort(res.begin(), res.end());
    return res;
}

int main() {
    string str; cin >> str;
    int n = (int)str.size();
    vector<long long> divs = divisor(n);
    long long syuuki = n;
    for (auto d : divs) {
        bool ok = true;
        for (int j = 0; j + d < n; ++j) {
            if (str[j] != str[j+d]) ok = false;
        }
        if (ok) syuuki = min(syuuki, d);
    }
    if (syuuki == n) cout << 1 << endl << 1 << endl;
    else if (syuuki == 1) cout << n << endl << 1 << endl;
    else {
        vector<int> cannot_cut(n*2, 0);
        string str2 = str;
        reverse(str2.begin(), str2.end());
        KMP kmp1(str);
        KMP kmp2(str2);
        for (int d = 1; d < n; ++d) {
            if (kmp1.period(d) < d && d % kmp1.period(d) == 0) cannot_cut[d] = true;
            if (kmp2.period(d) < d && d % kmp2.period(d) == 0) cannot_cut[n-d] = true;
        }
        int con = 0;
        for (int i = 1; i < n; ++i) if (!cannot_cut[i]) ++con;
        cout << 2 << endl << con << endl;
    }
}