けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 050 D - Suffix Concat (試験管橙色)

めっちゃ面白い問題だった! Suffix Array の練習問題。

問題概要

英小文字からなる長さ  N の文字列  S が与えられる。

 S の suffix は空文字列を除いて  N 個ある。これらの suffix を適切な順序で連結させて 1 つの文字列を作るとき、それが辞書順最小となるような順序を求めよ。

制約

  •  1 \le N \le 10^{5}

考えたこと

まずこの問題の前提となる知見として、文字列  a_{0}, a_{1}, \dots, a_{N-1} を適切な順序で連結して辞書順最小になるようにする方法は、連結させたどの隣接する文字列  p, q に対しても

 p + q \lt q + p

が成り立つようにすることだ。このことの証明は次の記事にて。

drken1215.hatenablog.com

ここで、 S i 文字目以降の文字列 (S.substr(i) のこと) を  S_{i} と書くことにする。 S_{i} の長さは  |S_{i}| = N - i となる。

一般に 2 つの suffix  S_{i}, S_{j} に対して

 S_{i} + S_{j} \lt S_{j} + S_{i}

となる条件を考察しよう。簡単のため、 i \lt j と仮定する。 j \lt i の場合は比較関数をひっくり返せばよい。 i \lt j のとき、 |S_{i}| \gt |S_{j}| となることに注意しておく。

 S_{j} S_{i} の prefix になっていないとき

 S_{j} S_{i} の prefix になっているかどうかの判定は、lcp(i, j) == N-j かどうかによって判定できる。まず、 S_{j} S_{i} の prefix になっていない場合を考察する。

この場合は下図のように、 S_{i} + S_{j} S_{j} + S_{i} とを比べたときに、先頭から長さ  |S_{j}| 以内の範囲で、文字が異なる箇所がある。

よって、 S_{i} + S_{j} S_{j} + S_{i} の辞書順の大小関係は、単純に  S_{i} S_{j} の辞書順の大小関係に一致する。つまり、Suffix Array のランク配列を rank として、


rank[i] < rank[j]


によって判定できる。

 S_{j} S_{i} の prefix になっているとき

この場合が要注意だ。

この場合は  l = |S_{j}| (= N-j) として、下図で青く示したように、 S_{i} + S_{j} S_{j} + S_{i} の先頭の  l 文字と末尾の  l 文字は一致することになる。末尾が一致することについては suffix の定義から従う。

よって、先頭と末尾の  l 文字分を除外して残された部分は

  •  S_{i} + S_{j} S_{i+l}
  •  S_{i} + S_{j} S_{i} の先頭から  l 文字分

となる。よって

  •  S_{i+l} \gt S_{i} である場合は、確実に  S_{i} + S_{j} \gt S_{j} + S_{i} であると言える
  •  S_{i+l} \lt S_{i} である場合は、
    •  S_{i+l} \lt S_{i} の先頭から  l 文字分」である場合は  S_{i} + S_{j} \lt S_{j} + S_{i} である
    •  S_{i+l} = S_{i} の先頭から  l 文字分」である場合は、 S_{i} + S_{j} = S_{j} + S_{i} である

 =」になる場合が微妙ではあるが、まとめると

 S_{i+l} \lt S_{i}  \Leftrightarrow  S_{i} + S_{j} \le S_{j} + S_{i}

であると言える。よって、比較関数は


rank[i+l] < rank[i]


と実装すれば問題ない。

コード

上述の比較関数に基づいて、 0, 1, \dots, N-1 をソートすればよい。計算量は

  • Suffix Array を構築して rank を求める: O(N)
  • 比較関数の処理:lcp を求める部分に Sparse Table を用いると、前処理  O(N \log N)、各比較処理は  O(1) でできる
  • ソートの実行: O(N \log N) でできる

以上より、全体として  O(N \log N) の計算量で求められる。

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int,int>;

// Sparse Table
template<class MeetSemiLattice> struct SparseTable {
    vector<vector<MeetSemiLattice> > dat;
    vector<int> height;
    
    SparseTable() { }
    SparseTable(const vector<MeetSemiLattice> &vec) { init(vec); }
    void init(const vector<MeetSemiLattice> &vec) {
        int n = (int)vec.size(), h = 0;
        while ((1<<h) < n) ++h;
        dat.assign(h, vector<MeetSemiLattice>(1<<h));
        height.assign(n+1, 0);
        for (int i = 2; i <= n; i++) height[i] = height[i>>1]+1;
        for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
        for (int i = 1; i < h; ++i)
            for (int j = 0; j < n; ++j)
                dat[i][j] = min(dat[i-1][j], dat[i-1][min(j+(1<<(i-1)),n-1)]);
    }
    
    MeetSemiLattice get(int a, int b) {
        return min(dat[height[b-a]][a], dat[height[b-a]][b-(1<<height[b-a])]);
    }
};

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;    // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> rank;  // rank[sa[i]] = i
    vector<int> lcp;   // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    SparseTable<int> st;  // use for calcultating lcp(i, j)

    // getter
    int& operator [] (int i) {
        return sa[i];
    }
    vector<int> get_sa() { return sa; }
    vector<int> get_rank() { return rank; }
    vector<int> get_lcp() { return lcp; }

    // constructor
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        vector<int> s;
        for (int i = 0; i < (int)str.size(); ++i) {
            s.push_back(str[i] + 1);
        }
        s.push_back(0);
        sa = sa_is(s);
        calcLCP(s);
        buildSparseTable();
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // find min id that str.substr(sa[id]) >= T
    int lower_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], string::npos, T) < 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // find min id that str.substr(sa[id], T.size()) > T
    int upper_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], T.size(), T) <= 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // search
    bool is_contain(const Str& T) {
        int lb = lower_bound(T);
        if (lb >= sa.size()) return false;
        return str.compare(sa[lb], T.size(), T) == 0;
    }

    // find lcp
    void calcLCP(const vector<int> &s) {
        int N = (int)s.size();
        rank.assign(N, 0), lcp.assign(N, 0);
        for (int i = 0; i < N; ++i) rank[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rank[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rank[i] - 1] = h;
        }
    }
    
    // build sparse table for calculating lcp
    void buildSparseTable() {
        st.init(lcp);
    }

    // calc lcp of str.sutstr(a) and str.substr(b)
    int getLCP(int a, int b) {
        return st.get(min(rank[a], rank[b]), max(rank[a], rank[b]));
    }
};

int main() {
    int N;
    string S;
    cin >> N >> S;

    // Suffix Array の構築
    SuffixArray<string> suf(S);
    vector<int> rank = suf.get_rank();

    // suffix の比較
    function<bool (int, int)> cmp = [&](int i, int j) {
        // i < j を仮定できるようにする
        if (i == j) return true;
        if (i > j) return not cmp(j, i);

        // i と j の lcp
        int len = suf.getLCP(i, j);
        
        if (N-j > len) {
            // i と j が prefix の関係にないならば小さい順
            return rank[i] < rank[j];
        } else {
            // i の len 文字目以降と i を比較
            return rank[i+len] < rank[i];
        }
    };

    // ソート
    vector<int> ids(N);
    for (int i = 0; i < N; ++i) ids[i] = i;
    sort(ids.begin(), ids.end(), cmp);
    for (auto v : ids) cout << v+1 << endl;
}

Educational Codeforces Round 9 C. The Smallest String Concatenation (R1700)

すごくシンプルな面白い問題。

問題概要

 N 個の文字列  S_{0}, S_{1}, \dots, S_{N-1} が与えられる。

これらを並び替えて連結して 1 つの文字列を作る。作れる文字列のうち、辞書順最小のものを求めよ。

制約

  •  1 \le N \le 5 \times 10^{4}
  •  1 \le |S_{i}| \le 50

解法

単純に  S_{0}, S_{1}, \dots, S_{N-1} を辞書順にソートして、小さい順に連結するのでは反例がある。

たとえば  S_{0} = "ab"、 S_{1} = "abaab" のとき、

  • 辞書順に並び替えると  S_{0}, S_{1} で、この順に連結すると "ababaab" になる
  •  S_{1} を先に連結すると、"abaabab" になる

後者の方が辞書順として小さい。

a + b < b + a で定義

そこで、 N 個の文字列に対して、辞書順に代わる新たな順序を定義する。文字列  a, b に対して

 a \lt\lt b  \Leftrightarrow  a + b \lt b + a

と定義する。この順序は推移率を満たす。つまり、 a \lt\lt b かつ  b \lt\lt c ならば、 a \le\lt c が成り立つ。このことは英小文字からなる文字列を 26 進法の整数とみなすことで納得できる。

文字列  a, b を表す整数を  f(a), f(b) とすると、次のようになる。

  • 文字列  a + b を表す整数は  10^{|b|} f(a) + f(b)
  • 文字列  b + a を表す整数は  10^{|a|} f(b) + f(a)

よって、 a + b \lt b + a

 10^{|b|} f(a) + f(b) \lt 10^{|a|} f(b) + f(a)
 \Leftrightarrow \displaystyle \frac{f(a)}{10^{|a|}-1} \lt \frac{f(b)}{10^{|b|}-1}

と同値になる。ゆえに、文字列  a, b, c に対して、

  •  a \lt\lt b ならば、 \Leftrightarrow \displaystyle \frac{f(a)}{10^{|a|}-1} \lt \frac{f(b)}{10^{|b|}-1}
  •  b \le\lt c ならば、 \Leftrightarrow \displaystyle \frac{f(b)}{10^{|b|}-1} \lt \frac{f(c)}{10^{|c|}-1}

ということになり、 \displaystyle \frac{f(a)}{10^{|a|}-1} \lt \frac{f(c)}{10^{|c|}-1} が成り立つ。つまり、 a \lt\lt c が成り立つ。

推移率が成り立つので

推移率が成り立つので、上で定義した順序に従って、 N 個の文字列  S_{0}, S_{1}, \dots, S_{N} を並び替えることができる。結論から言えば、その順序に連結することで、連結後の文字列が辞書順最小となる。

以下のことを示そう。


 0, 1, \dots, N-1 の順列  p_{0}, p_{1}, \dots, p_{N-1} に対して、 S_{p_{0}} + S_{p_{1}} + \dots + S_{p_{N-1}} をこの順に連結して辞書順最小であるための必要十分条件は、各  i = 0, 1, \dots, N-1 に対して

 S_{p_{i}} \lt\lt S_{p_{i+1}} ( \Leftrightarrow  S_{p_{i}} + S_{p_{i+1}} \lt S_{p_{i+1}} + S_{p_{i}})

が成り立つことである。


必要条件であることは明らか (満たさなければ改善できることが明らか) なので、十分条件であることが示せればよい。

簡単のため、 S_{0} \lt\lt S_{1} \lt\lt \dots \lt\lt S_{N-1} のとき、任意の順列  q_{0}, q_{1}, \dots, q_{N-1} に対して、 S_{0} + S_{1} + \dots + S_{N-1} \lt S_{q_{0}} + S_{q_{1}} + \dots + S_{q_{N-1}} が成立することを言えばよい。

任意の順列が隣り合う要素を swap することで作れることから言える。ここではイメージを掴むために、具体例で見ていくことにする。一般の順列に対しても同様に示せる。

たとえば  N = 4 として、 S_{0} + S_{1} + S_{2} + S_{3} \lt S_{2} + S_{1} + S_{3} + S_{0} が成り立つことを示そう。

 S_{0} + S_{1} + S_{2} + S_{3}
 \lt S_{0} + S_{2} + S_{1} + S_{3}
 \lt S_{2} + S_{0} + S_{1} + S_{3}
 \lt S_{2} + S_{1} + S_{0} + S_{3}
 \lt S_{2} + S_{1} + S_{3} + S_{0}

というように示せる。

コード

 N 個の文字列  S_{0}, S_{1}, \dots, S_{N-1} を上述の順序に従って並び替えて、その順に連結すればよい。

計算量は  M = \max_{i} |S_{i}| として、 O(MN\log N) と評価できる。

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    cin >> N;
    vector<string> S(N);
    for (int i = 0; i < N; ++i) cin >> S[i];

    // S を並び替え
    auto cmp = [&](const string& a, const string& b) -> bool {
        return a + b < b + a;
    };
    sort(S.begin(), S.end(), cmp);

    // 連結
    string res = "";
    for (auto s : S) res += s;
    cout << res << endl;
}

AOJ 2644 Longest Match (JAG 夏合宿 2014 day4-F) (700 点)

Suffix Array の練習問題

問題概要

英小文字からなる文字列  S と、 Q 個のクエリが与えられる。

各クエリでは 2 つの文字列  x, y が与えられる。

文字列  S の連続する部分文字列であって、 x で始まり、 y で終わるものの中で、最長の長さを答えよ。ただし、そのような部分文字列が存在しない場合には 0 と答えよ。

制約

  •  1 \le |S| \le 2 \times 10^{5}
  •  1 \le Q \le 10^{5}
  • 各クエリで与えられる文字列  x,  y の長さの合計は  2 \times 10^{5} 以下

考えたこと

Suffix Array を用いて文字列検索する手法がある。それを応用して今回の問題は解ける。なお、Suffix Array を構築したあとでは、文字列  T を検索するのに要する計算量は  O(|T| \log |S|) であり、非常に高速である。

そもそも  x, y S に含まれていなければ、答えは 0 としてよい。

まず、文字列  x から始まる suffix のうち、最小の添字 xres を求める。Suffix Array 上では、文字列  x から始まる suffix は連続した区間になる。その区間の左端と右端は Suffix Array の lower_bound によって求められる (蟻本 P.338 参照)。その区間内での、文字列  S の index の最小値を求めれば良い。これは RMQ を用いて高速に求められる。

次に、文字列  y から始まる suffix のうち、最大の添字 yres を求める。これも同様に RMQ を用いて求められる。

求める答えは次の通り。

  • xres > yres の場合:答えは 0
  • xres + x.size() > yres + y.size() の場合:答えは 0
  • それ以外:答えは yres - xres + y.size();

コード

計算量は、クエリに登場する文字列の長さの合計値を  M として、

  • Suffix Array の構築: O(|S|)
  • 各クエリの処理: O(M \log |S|)

となる。

#include <bits/stdc++.h>
using namespace std;

// Sparse Table
template<class MeetSemiLattice> struct SparseTable {
    vector<vector<MeetSemiLattice> > dat;
    vector<int> height;
    
    SparseTable() { }
    SparseTable(const vector<MeetSemiLattice> &vec) { init(vec); }
    void init(const vector<MeetSemiLattice> &vec) {
        int n = (int)vec.size(), h = 0;
        while ((1<<h) < n) ++h;
        dat.assign(h, vector<MeetSemiLattice>(1<<h));
        height.assign(n+1, 0);
        for (int i = 2; i <= n; i++) height[i] = height[i>>1]+1;
        for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
        for (int i = 1; i < h; ++i)
            for (int j = 0; j < n; ++j)
                dat[i][j] = min(dat[i-1][j], dat[i-1][min(j+(1<<(i-1)),n-1)]);
    }
    
    MeetSemiLattice get(int a, int b) {
        return min(dat[height[b-a]][a], dat[height[b-a]][b-(1<<height[b-a])]);
    }
};

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;    // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> rank;  // rank[sa[i]] = i
    vector<int> lcp;   // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    SparseTable<int> st;  // use for calcultating lcp(i, j)

    // getter
    int& operator [] (int i) {
        return sa[i];
    }
    vector<int> get_sa() { return sa; }
    vector<int> get_rank() { return rank; }
    vector<int> get_lcp() { return lcp; }

    // constructor
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        vector<int> s;
        for (int i = 0; i < (int)str.size(); ++i) {
            s.push_back(str[i] + 1);
        }
        s.push_back(0);
        sa = sa_is(s);
        calcLCP(s);
        buildSparseTable();
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // find min id that str.substr(sa[id]) >= T
    int lower_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], string::npos, T) < 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // find min id that str.substr(sa[id], T.size()) > T
    int upper_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], T.size(), T) <= 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // search
    bool is_contain(const Str& T) {
        int lb = lower_bound(T);
        if (lb >= sa.size()) return false;
        return str.compare(sa[lb], T.size(), T) == 0;
    }

    // find lcp
    void calcLCP(const vector<int> &s) {
        int N = (int)s.size();
        rank.assign(N, 0), lcp.assign(N, 0);
        for (int i = 0; i < N; ++i) rank[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rank[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rank[i] - 1] = h;
        }
    }
    
    // build sparse table for calculating lcp
    void buildSparseTable() {
        st.init(lcp);
    }

    // calc lcp of str.sutstr(a) and str.substr(b)
    int getLCP(int a, int b) {
        return st.get(min(rank[a], rank[b]), max(rank[a], rank[b]));
    }
};

void solveAOJ2644() {
    // Suffix Array の構築
    string S;
    cin >> S;
    SuffixArray<string> suf(S);
    vector<int> sa = suf.get_sa();

    // Suffix Array の区間に関する RMQ
    SparseTable<int> min_st(sa);
    for (auto& v : sa) v = -v;
    SparseTable<int> max_st(sa);

    auto solve = [&](const string& x, const string& y) -> int {
        if (!suf.is_contain(x) || !suf.is_contain(y)) {
            return 0;
        }
        int xl = suf.lower_bound(x);
        int xr = suf.upper_bound(x);
        int yl = suf.lower_bound(y);
        int yr = suf.upper_bound(y);
        int xres = min_st.get(xl, xr);
        int yres = -max_st.get(yl, yr);

        if (xres > yres || xres + x.size() > yres + y.size()) 
            return 0;
        else 
            return yres - xres + (int)y.size();   
    };

    int Q;
    cin >> Q;
    while (Q--) {
        string x, y;
        cin >> x >> y;
        cout << solve(x, y) << endl;
    }
}

int main() {
    solveAOJ2644();
}

DISCO presents 2016 予選 D - DDPC特別ビュッフェ (赤色)

久しぶりに bit ベクター高速化を使った。デバッグがしんどかった。

問題概要

長さ  N の数列  A_{1}, A_{2}, \dots, A_{N} と、長さ  M の数列  B_{1}, B_{2}, \dots, B_{M} が与えられる。今、次の操作をちょうど  K 回実行する

  •  1 \le i \le N 1 \le j \le M を選ぶ
  •  A_{i} B_{j} とを swap する

操作実行後の  \displaystyle \bigl(\sum_{i=1}^{N} A_{i}\bigr)\bigl(\sum_{i=1}^{M} B_{i}\bigr) の値の最大値を求めよ。

制約

  •  1 \le N, M \le 55
  •  1 \le K \le 999
  •  0 \le A_{i}, B_{i} \le 22222

考えたこと

どのように操作しても、 A の総和と  B の総和の和は一定 ( S とおく) なので、 \displaystyle \bigl(\sum_{i=1}^{N} A_{i}\bigr)\bigl(\sum_{i=1}^{M} B_{i}\bigr) を最大にするためには、 \displaystyle \bigl(\sum_{i=1}^{N} A_{i}\bigr) \displaystyle \bigl(\sum_{i=1}^{N} B_{i}\bigr) の差を最小にしたい。

言い換えれば、操作後の  \displaystyle \bigl(\sum_{i=1}^{N} A_{i}\bigr) の値をできるだけ  \frac{S}{2} に近づけたい。部分和問題っぽい雰囲気なので DP で解けそうだ。いつもの DP によって、

  • adp[k][v] ← 数列  A から  k 個選んで総和を  v にできるかどうか
  • bdp[k][v] ← 数列  B から  k 個選んで総和を  v にできるかどうか

が求められる。これらを求めるのに計算量は  O((N^{2}+M^{2})S) を要するので通常は間に合わない。そこで常套手段として、bitset を用いた高速化をする。

集計

基本的には adp[i][x] が True となる x と bdp[N-i][y] が True となる y の和 x + y が  \frac{S}{2} に近づくように集計する。

ここで、

  •  K = 1 のときは、 i = N-1, N-i = 1 であることが条件
  •  K > 1 のときは、 N-i \le \min(K, M) であることが条件

ということに注意しておく。

上記の条件を満たす各  i に対して、adp[i][x] == True であるような x に対して、

  • S/2 以下の y であって bdp[N-i][y] == Trueとなる最大の y
  • (S+1)/2 以上の y であって bdp[N-i][y] == Trueとなる最小の y

をそれぞれ求めて、(x+y)(S-x-y) の最大値を求めていく。この部分の計算量は  O(NS) となる。十分間に合う。

コード

#include <bits/stdc++.h>
using namespace std;
constexpr long long MAX = 1500000;

// 前処理の DP
vector<bitset<MAX>> enum_sum(const vector<int>& A) {
    int N = A.size();
    vector<bitset<MAX>> dp(N+1);
    dp[0][0] = true;
    for (int i = 0; i < N; ++i) {
        vector<bitset<MAX>> ndp(N+1);
        for (int j = 0; j <= N; ++j) {
            // 選ぶ
            if (j < N) ndp[j+1] |= dp[j] << A[i];

            // 選ばない
            ndp[j] |= dp[j];
        }
        swap(dp, ndp);
    }
    return dp;
}

// 集合 a と集合 b から要素を選んで和を S/2 に近づけたい
long long solve(const bitset<MAX>& a, const bitset<MAX>& b, long long S) {
    long long res = 0;

    // left[v] := (x <= v かつ b[x] = true となる最大の x)
    // right[v] := (x >= v かつ b[x] = true となる最小の x)
    vector<long long> left(MAX, -1), right(MAX, -1);
    for (int i = 0; i < MAX; ++i) {
        if (b[i]) left[i] = i;
        else if (i-1 >= 0) left[i] = left[i-1];
    }
    for (int i = MAX-1; i >= 0; --i) {
        if (b[i]) right[i] = i;
        else if (i+1 < MAX) right[i] = right[i+1];
    }

    // 集合 a の値 v に応じて、最適な b を選ぶ
    for (int v = 0; v < MAX; ++v) {
        if (!a[v]) continue;
        long long bv1 = min(max(S/2-v, 0LL), MAX);
        long long bv2 = min(max((S+1)/2-v, 0LL), MAX);
        if (left[bv1] != -1)
            res = max(res, (v + left[bv1]) * (S - v - left[bv1]));
        if (right[bv2] != -1)
            res = max(res, (v + right[bv2]) * (S - v - right[bv2]));
    }
    return res;
}

int main() {
    // 入力
    int N, M, K;
    cin >> N >> M >> K;
    vector<int> A(N), B(M);
    long long S = 0;
    for (int i = 0; i < N; ++i) { cin >> A[i]; S += A[i]; }
    for (int i = 0; i < M; ++i) { cin >> B[i]; S += B[i]; }

    // adp[i] := A から i 個選んで作れる数を表す bitset
    // bdp[j] := B から j 個選んで作れる数を表す bitset
    const auto& adp = enum_sum(A);
    const auto& bdp = enum_sum(B);

    // adp[i] と bdp[N-i] から S/2 に近い数を作る
    long long res = 0;
    for (int i = 0; i <= N; ++i) {
        if (K == 1 && i == N) continue;
        if (N-i > min(K, M)) continue;
        res = max(res, solve(adp[i], bdp[N-i], S));     
    }
    cout << res << endl;
}

DISCO presents 2016 予選 B - ディスコ社内ツアー (青色)

シミュレーションの仕方を工夫する問題。数値ごとに index をまとめあげるデータを持つとうまくいく。

問題概要

長さ  N の正の整数列  A_{1}, A_{2}, \dots, A_{N} が与えられる。

 1, 2, \dots, N の順列  p_{1}, p_{2}, \dots, p_{N} であって、

 A_{p_{1}} \le A_{p_{2}} \le \dots \le A_{p_{N}}

を満たすものを考える。そのような順列の  p_{i} \gt p_{i+1} (ただし  p_{N} = 1 である場合の  i = N-1 の場合は除外) となっている箇所の個数の最小値を求めよ。

制約

  •  2 \le N \le 10^{5}
  •  1 \le A_{i} \le 10^{5}

考えたこと

本当に愚直にやると TLE になる。

  • ids[v]A[i] == v となる index i の集合

というデータをもつと上手に処理できる。僕の実装だと、 M = \max_{i} A_{i} として、計算量は  O(N + M \log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;
const int MAX = 110000;

int main() {
    vector<vector<int>> ids(MAX);

    // 入力
    int N;
    cin >> N;
    vector<int> A(N);
    for (int i = 0; i < N; ++i) {
        cin >> A[i];
        ids[A[i]].push_back(i);
    }

    int res = 0;
    int cur = 0;
    for (int v = 0; v < MAX; ++v) {
        if (ids[v].empty()) continue;
        int pl = lower_bound(ids[v].begin(), ids[v].end(), cur) - ids[v].begin();
        if (cur > ids[v][0]) cur = ids[v].back();
        else {
            ++res;
            cur = ids[v][pl-1];
        }
    }
    if (cur > 0) ++res;
    cout << res << endl;
}