けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 286 C - Rotate and Palindrome (茶色, 300 点)

慣れれば解ける問題だけど、最初は「固定する」という考え方が難しいかもしれない。

問題概要

長さ  N の文字列  S が与えられる。この文字列に対して、次の操作を繰り返すことで回文にしたい。

  • 先頭の文字を末尾に移動する (コスト  A)
  • 文字を 1 つ変更する (コスト  B)

回文にするための最小コストを求めよ。

解法

まともに考えると、色んなパターンがありそうで頭がこんがらがってしまう!!!

こういうときは「操作の流れを単純化できないか」を考えるとよかったりする。今回は、操作の流れを次の 2 ステップに分けて考えとうまくいく。


  • Step 1:先頭の文字を何文字か、末尾に移動する
  • Step 2:その結果できた文字列に対して、前半の文字を適切に置き換えることで回文に一致するようにする

つまり、一度コスト A の操作をやり切ったあとは、コスト B の操作のみに専念すればよいということだ。

そのように考えて良い理由

文字を置き換えてから、先頭の文字を末尾に移動するという 2 回の操作は、

先に先頭の文字を末尾に移動してから、上述の文字を置き換える操作

に置き換えても結果は等しくなる。よって、任意の操作列は、その最終結果を変更することなく、「先に先頭の文字を末尾に移動する操作をすべてやり切るような操作列」に変形できる。

コスト A の操作回数を固定する

以上の考察で考えやすくなった。あとは、コスト A の操作回数  i を固定して考えよう。各  i についての最小コストをそれぞれ求めて、その最小値をとればよいのだ。

さて、コスト A の操作回数を固定すると、次の問題を解けばよいことになる。


文字列  T ( S の先頭  i 文字を末尾に移動したもの) が与えられる。

 T の文字をいくつか変えることで回文にしたい。1 文字変えるごとにコスト  B を消費する。最小コストを求めよ。


これは下図のように、左から  i 文字目と右から  i 文字目が異なっている箇所の個数を数えればよい。これは  O(N) の計算量で実行できる。

全体としては、計算量は  O(N^{2}) となる。

コード

Python

# 入力
N, A, B = map(int, input().split())
S = input()

# 先頭の文字を移す回数を固定して考える
res = 1<<60
for i in range(N):
    # 先頭 i 個を末尾に移す場合の最終コストを求める
    cost = 0
    for j in range(N//2):
        if S[j] != S[N-j-1]:
            cost += B
    res = min(res, cost + A * i)

    # 先頭の文字を末尾に移動
    S = S[1:] + S[0]

print(res)

C++

#include<bits/stdc++.h>
using namespace std;

int main() {
    // 入力
    long long N, A, B;
    string S;
    cin >> N >> A >> B >> S;

    // 先頭の文字を移す回数を固定して考える
    long long res = 1LL<<60;
    for (int i = 0; i < N; ++i) {
        // 先頭 i 個を末尾に移す場合の最終コストを求める
        long long cost = 0;
        for (int j = 0; j < N/2; ++j) {
            if (S[j] != S[N-j-1]) cost += B;
        }
        res = min(res, cost + A * i);

        // 先頭の文字を末尾に移動
        S = S.substr(1) + S[0];
    }
    cout << res << endl;
}

JOI 二次予選 2023 A - 年齢の差 (難易度 3)

シンプルながらも、学べるポイントがたくさんある問題ですね

問題概要

JOI 市には  1 から  N までの番号が付けられた  N 人の住民がいて、住民  i ( 1 \le i \le N) の年齢は  A_{i​} 歳です。

JOI 市の住民の年齢  A_{1}​, A_{2}​, \dots, A_{N}​ が与えられます。 i = 1, 2, \dots, N に対して、住民  i と他の住民との年齢の差の最大値を求めるプログラムを作成してください。

制約

  •  2 \le N \le 250000
  •  0 \le A_{i} \le 10^{9}

解法

ステップ 1:まず 0-indexed にする

JOI の問題文では、与えられる配列の添字が 1 始まりであることがよくあります。今回も  N 人の住民の年齢が  A_{1}, A_{2}, \dots, A_{N} というように、1 始まりになっています。

しかし C++ や Python をはじめ、多くのプログラミング言語では、配列の先頭の添字は 0 です。そこで多くの場合、問題の添字を 0 始まりになるように読み替えるとよいでしょう。そうすると、次のようになります。


JOI 市には  0 から  N-1 までの番号が付けられた  N 人の住民がいて、住民  i ( 0 \le i \le N-1) の年齢は  A_{i​} 歳です。

JOI 市の住民の年齢  A_{0}​, A_{1}​, \dots, A_{N-1}​ が与えられます。 i = 0, 1, \dots, N-1 に対して、住民  i と他の住民との年齢の差の最大値を求めるプログラムを作成してください。


もとの問題文と変わった部分に注目してみてください。なお、このように、0 始まりの添字を 0-indexed と呼びます。

ステップ 2:問題の意味を把握する

まずは問題の意味を頑張って把握しましょう。もしかしたら、「 i = 0, 1, \dots, N-1 に対して、〜してください」という趣旨の問題を初めて解く方も多いかもしれないですね。

一瞬「?」が飛ぶ人も多いかもしれませんが、難しくはありません。次のようにすればよいです。

for (int i = 0; i < N; ++i) {
    // i についての問題を解く

}

まずは、やってみましょう。今回は「 i についての問題」とは、次のような問題です。


  •  A_{0} A_{i} の差
  •  A_{1} A_{i} の差
  • ...
  •  A_{N-1} A_{i} の差

をそれぞれ求めて、その最大値を求めてください


以上を踏まえると、次のようなコードが書けそうです。

#include <iostream>
#include <vector>
using namespace std;

int main() {
    // 入力
    int N;
    cin >> N;
    vector<int> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // 各 i についての問題を解く
    for (int i = 0; i < N; ++i) {
        int res = 0;

        // A[i] と、A{0], A[1], ..., A[N-1] との差の最大値を求める
        for (int j = 0; j < N; ++j) {
            res = max(res, abs(A[i] - A[j]));
        }
        cout << res << endl;
    }
}

この解法で部分点を獲得できます。具体的には 100 点中 55 点を獲得できます。

しかし全体としては TLE (Time Limit Exceeded) という判定になるはずです。プログラムの実行に時間がかかりすぎて、計算実行制限時間である 1 秒以内に処理を終えられないということです。なぜそのようになるかを考えてみましょう。

上のコードを見ると、for 文が 2 重になっています。

  • 1 個目の for 文で  i = 0, 1, \dots, N-1 について考えていて
  • それぞれについて 2 個目の for 文で  j = 0, 1, \dots, N-1 について考えている

といった具合になっています。よって、for 文で扱っている添字  i, j の組として考えられる値は  N \times N = N^{2} 通りとなります。

つまり、上記のコードは  N^{2} に比例する計算時間を要するということになります。このことを、計算量 O(N^{2}) であるといいます。

このように計算時間を  O() という記法で表したものを計算量と言います。今後より難しい問題を解いていくためには、計算量について理解を深めていくことが重要です。

なお、コンピュータが 1 秒間に処理できる計算ステップ回数は  10^{9} 回程度と言われています。今回の問題では  N \le 250000 という制約があるため、

 N^{2} \le 62500000000 \simeq 6 \times 10^{10}

となります。このことから、 O(N^{2}) の計算量をもつプログラムは、実行制限時間である 1 秒以内に処理を終えられないことがわかります。

なお、計算量についてより詳しくは次の記事で勉強してみてください。

qiita.com

ステップ 3:探索候補を絞って高速化

そこで、プログラムを高速化しましょう。たとえば、 A = (3, 5, 6, 9, 11, 15, 17, 21) という場合を考えてみます。

このとき、

  • A[0] = 3 との差が最大の要素は 21 で、差は 21 - 3 = 18
  • A[1] = 5 との差が最大の要素は 21 で、差は 21 - 5 = 16
  • A[2] = 6 との差が最大の要素は 21 で、差は 21 - 6 = 15
  • A[3] = 9 との差が最大の要素は 21 で、差は 21 - 9 = 12
  • A[4] = 11 との差が最大の要素は 21 で、差は 21 - 11 = 10
  • A[5] = 15 との差が最大の要素は 3 で、差は 15 - 3 = 12
  • A[6] = 17 との差が最大の要素は 3 で、差は 17 - 3 = 14
  • A[7] = 21 との差が最大の要素は 3 で、差は 21 - 3 = 18

というようになっています。気付くことは、A[i] との差が最大になるのは、「A の最大値である 21」と「A の最小値である 3」のどちらかだということです。

このことは一般に言えます。つまり、一般の配列 A に対しても、

  • ma ← 配列 A の最大値
  • mi ← 配列 A の最小値

として、A[i] との差の最大値は

max(ma - A[i], A[i] - mi)

と求められます。最初の解法とは異なり、この計算はとても高速に実行できます。

計算量

最後に、この解法の計算量を見積もってみましょう。

  • mami を求めるのに要する計算量: O(N)
  • A[i] に対して、max(ma - A[i], A[i] - mi) を求める計算量: O(N)

よって、全体の計算量も  O(N) となります。

コード

#include <iostream>
#include <vector>
using namespace std;

int main() {
    // 入力
    int N;
    cin >> N;
    vector<int> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // A の最小値と最大値を求める
    int mi = A[0], ma = A[0];
    for (int i = 0; i < N; ++i) {
        mi = min(mi, A[i]);
        ma = max(ma, A[i]);
    }

    // 各 i についての問題を解く
    for (int i = 0; i < N; ++i) {
        cout << max(ma - A[i], A[i] - mi) << endl;
    }
}

AtCoder ABC 281 Ex - Alchemy (赤色, 600 点)

この平方分割のやり方はちゃんとマスターしたい

問題概要

 A 種類のレベル 1 の宝石がある。各種類のレベル 1 の宝石は無限個ある。

 n 個の宝石を合成することで、レベル  n の宝石を作ることができる。ただし、その  n 個の宝石は次の条件を満たす必要がある:

  • すべてレベル  n 未満である
  • 2 以上の整数  x については「レベル  x の宝石」は高々 1 個である
  • どの 2 個も種類が異なる (レベル 2 以上の宝石の種類が異なるとは、原料となった宝石の種類が異なることを意味する)

レベル  N の宝石を何種類作れるかを 998244353 で割ったあまりで答えよ。

制約

  •  1 \le N \le 2 \times 10^{5}
  •  1 \le A \le 10^{9}

考えたこと

いきなりレベル  N の宝石の種類数を数えるのではなく、レベル  2, 3, \dots の宝石の種類数を数えることにする。なお、dp[i] を、レベル  i の種類数とする

レベル 1

 A 種類である

レベル 2

dp[2] =  {}_{A}\mathrm{C}_{2} 種類である

レベル 3
  • レベル 1 が 3 個の場合: {}_{A}\mathrm{C}_{3} 種類
  • レベル 1 が 2 個、レベル 2 が 1 個の場合: {}_{A}\mathrm{C}_{2} \times dp[2] 種類
レベル 4
  • レベル 1 が 4 個の場合: {}_{A}\mathrm{C}_{4} 種類
  • レベル 1 が 3 個、レベル 2 が 1 個の場合: {}_{A}\mathrm{C}_{3} \times dp[2] 種類
  • レベル 1 が 3 個、レベル 3 が 1 個の場合: {}_{A}\mathrm{C}_{3} \times dp[3] 種類
  • レベル 1 が 2 個、レベル 2 が 1 個、レベル 3 が 1 個の場合: {}_{A}\mathrm{C}_{3} \times dp[2]  \times dp[3] 種類

畳み込みへ

ここまで考えてみて「畳み込みっぽい」式になりそうだとわかる。

  • レベル 1 を表す多項式: f(x) = 1 +  {}_{A}\mathrm{C}_{1}x + {}_{A}\mathrm{C}_{2}x^{2} +  {}_{A}\mathrm{C}_{3}x^{3} + \dots + x^{A}
  • レベル 2 を表す多項式: h_{1}(x) = 1 + dp[1] x
  • レベル 3 を表す多項式: h_{2}(x) = 1 + dp[2] x^{2}
  • ...
  • レベル  n-1 を表す多項式: h_{n-1}(x) = 1 + dp [n-1] x^{n-1}

としたとき、

dp[n] =  \lbrack x^{n} \rbrack f(x)h_{1}(x)h_{2}(x) \dots h_{n-1}(x)

と表せることになる。ここまでで  O(N^{2}) 解法にはなった。具体的には、

for (int i = 1; i <= N; ++i) {
    dp[i] = f[i];
    h[i] = {1, dp[i]};  // h[i] = 1 + dp[i]x
    f *= h[i];
}

というように処理していく。

もし仮に、各  h_{i}(x) の係数が静的に決まるならば、よくある「二分木のような計算順序」によって、 O(N (\log N)^{2}) の計算量となる。

しかし今回は、 h_{i}(x) の係数が  h_{1}(x), h_{2}(x), \dots, h_{i-1}(x) によって動的に決まるため、難しいな......という気持ちになっていた。以前に分割統治 FFT というテクニックを学んでいて、それが使えないかと考えたけどわからなかった。

公式解説は分割統治 FFT

平方分割

公式解説はちょっと天才だと思った。でもコンテスト本番中に通されたコードはほとんど平方分割解法だった。それは無理ない解法だと感じたので、そっちをマスターすることにした。

毎回 f *= h[i] という多項式演算をする代わりに、多項式列  h_{1}, h_{2}, \dots h_{i} の後ろの方を一時的に計算して貯めておくための多項式  g を用意する方法が考えられる。

このとき、dp[i] の値は、 f \times g i 次の項を計算すればよいわけだが、その計算量は  \mathrm{deg}(g) で抑えられるのだ。 B = O(\sqrt{N}) 回ごとに、f *= g をすることにした場合、

  • 毎回の dp[i] を求める計算量: O(\sqrt{N})
  • f *= g の計算量: O(\sqrt{N}) 回 ×  O(N \log{N}) =  O(N \sqrt{N} \log{N})

ということで、全体の計算量は  O(N \sqrt{N} \log{N}) となる。

なお、 B = \sqrt{N \log N} とするともう少し早くなって、 O(N \sqrt{N \log N}) になる。

コード

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/convolution>
using namespace std;
using mint = atcoder::modint998244353;

// v に (x + a) をかける
void update(vector<mint>& v, int& deg, mint a) {
    ++deg;
    for (int d = deg; d >= 1; --d) v[d] += v[d-1] * a;
}
int main() {
    long long N, A;
    cin >> N >> A;

    // f, g
    vector<mint> f(N + 1, 0);
    f[0] = 1;
    for (int i = 1; i <= min(N, A); ++i) f[i] = f[i-1] * (A-i+1) / i;

    // 貯める計算過程を表す多項式を初期化する関数
    int B = 2000;
    vector<mint> tmp;
    int deg = 0;
    auto init = [&]() -> void {
        tmp.assign(B, 0), tmp[0] = 1, deg = 0;
    };

    // 計算過程を貯めながら計算する
    mint a = A;
    init();
    for (int n = 2; n <= N; ++n) {
        a = 0;
        for (int d = max(0, (int)(n+1-f.size())); d < tmp.size(); ++d) {
            if (n - d < 0) break;
            a += tmp[d] * f[n - d];
        }
        update(tmp, deg, a);
        if (deg == tmp.size() - 1) {
            f = atcoder::convolution(f, tmp);
            f.resize(N + 1);
            init();
        }
    }
    cout << a.val() << endl;
}

AtCoder ABC 280 Ex - Substring Sort (銅色, 600 点)

Suffix Tree 上で DFS した。デバッグがめっちゃ大変だった!

問題概要

 N 個の英小文字からなる文字列  S_{1}, S_{2}, \dots, S_{N} が与えられる。

これら  N 個の文字列について、連続する部分文字列をすべて考える。重複を除かずに考えると  M = \displaystyle \sum_{i=1}^{N} \frac{|S_{i}|(|S_{i}|+1)}{2} 個ある。

これら  M 個の文字列を辞書順にソートしたい。 Q 個の整数  1 \le x_{1} \lt x_{2} \lt \dots \lt x_{Q} \le M が与えられるので、各  q に対して、辞書順で  x_{q} 番目の文字列を求めよ。

具体的には、その文字列が  S_{K} L 文字目から  R 文字目までに一致するとき、3 つの整数組  (K, L, R) を出力せよ。複数通りある場合はどれを出力してもよい。

制約

  •  1 \le N \le 10^{5}
  •  \sum_{i=1}^{N}|S_{i}| \le 10^{5}
  •  1 \le Q \le 2 \times 10^{5}

考えたこと

文字列  S_{1}, S_{2}, \dots, S_{N} を "$" で挟んで連結した文字列  T の Suffix Tree を考える。ただし、各 suffix において、"$" 以降は無視するようにする。

たとえば、 S = ("aabb", "abababc") のとき、 T = "aabb$abababc" となり、 T の Suffix Array は次のようになる。

  • 0 番目:""
  • 1 番目:"$abababc"
  • 2 番目:"aabb$abababc"
  • 3 番目:"abababc"
  • 4 番目:"ababc"
  • 5 番目:"abb$abababc"
  • 6 番目:"abc"
  • 7 番目:"b$abababc"
  • 8 番目:"bababc"
  • 9 番目:"babc"
  • 10 番目:"bb$abababc"
  • 11 番目:"bc"
  • 12 番目:"c"

よって、 T の Suffix Tree は下図のようになる。"$" 以降を無視する操作は、高さ配列 lcp の値を適切に変形しておくことで実現できる。

求める文字列の辞書順は、この木の各ノードを dfs することで得られる。上図の場合、次のようになる。

  • "a", "a", "a", "a", "a"
  • "aa"
  • "aab"
  • "aabb"
  • "ab", "ab", "ab", "ab"
  • "aba", "aba"
  • "abab", "abab"
  • "ababa"
  • "ababab"
  • "abababc"
  • "ababc"
  • "abb"
  • "abc"
  • "b", "b", "b", "b", "b"
  • "ba", "ba"
  • "bab", "bab"
  • "baba"
  • "babab"
  • "bababc"
  • "babc"
  • "bb"
  • "bc"
  • "c"

各ノードは、suffix array の左端 left、右端 right と、深さ depth の 3 つの値で管理することにした。

計算量を評価する。Suffix Array のノード数は、 D = \sum_{i=1}^{N} |S_{i}| として  O(D) となる。下の実装では、区間 [left, right) を再帰的に分割していく部分で、高さ配列 lcp を乗せたセグメント木を用いているため、計算量は全体として  O(D \log D + Q) となる。

コード

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;
using pll = pair<long long, long long>;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

// Segment Tree
template<class Monoid> struct SegTree {
    using Func = function<Monoid(Monoid, Monoid)>;
    int N;
    Func F;
    Monoid IDENTITY;
    int SIZE_R;
    vector<Monoid> dat;

    /* initialization */
    SegTree() {}
    SegTree(int n, const Func f, const Monoid &identity)
    : N(n), F(f), IDENTITY(identity) {
        SIZE_R = 1;
        while (SIZE_R < n) SIZE_R *= 2;
        dat.assign(SIZE_R * 2, IDENTITY);
    }
    void init(int n, const Func f, const Monoid &identity) {  
        N = n;
        F = f;
        IDENTITY = identity;
        SIZE_R = 1;
        while (SIZE_R < n) SIZE_R *= 2;
        dat.assign(SIZE_R * 2, IDENTITY);
    }
    
    /* set, a is 0-indexed */
    /* build(): O(N) */
    void set(int a, const Monoid &v) { dat[a + SIZE_R] = v; }
    void build() {
        for (int k = SIZE_R - 1; k > 0; --k)
            dat[k] = F(dat[k*2], dat[k*2+1]);
    }
    
    /* update a, a is 0-indexed, O(log N) */
    void update(int a, const Monoid &v) {
        int k = a + SIZE_R;
        dat[k] = v;
        while (k >>= 1) dat[k] = F(dat[k*2], dat[k*2+1]);
    }
    
    /* get [a, b), a and b are 0-indexed, O(log N) */
    Monoid get(int a, int b) {
        Monoid vleft = IDENTITY, vright = IDENTITY;
        for (int left = a + SIZE_R, right = b + SIZE_R; left < right; 
        left >>= 1, right >>= 1) {
            if (left & 1) vleft = F(vleft, dat[left++]);
            if (right & 1) vright = F(dat[--right], vright);
        }
        return F(vleft, vright);
    }
    Monoid all_get() { return dat[1]; }
    Monoid operator [] (int a) { return dat[a + SIZE_R]; }
    
    /* get max r that f(get(l, r)) = True (0-indexed), O(log N) */
    /* f(IDENTITY) need to be True */
    int max_right(const function<bool(Monoid)> f, int l = 0) {
        if (l == N) return N;
        l += SIZE_R;
        Monoid sum = IDENTITY;
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(F(sum, dat[l]))) {
                while (l < SIZE_R) {
                    l = l * 2;
                    if (f(F(sum, dat[l]))) {
                        sum = F(sum, dat[l]);
                        ++l;
                    }
                }
                return l - SIZE_R;
            }
            sum = F(sum, dat[l]);
            ++l;
        } while ((l & -l) != l);  // stop if l = 2^e
        return N;
    }

    /* get min l that f(get(l, r)) = True (0-indexed), O(log N) */
    /* f(IDENTITY) need to be True */
    int min_left(const function<bool(Monoid)> f, int r = -1) {
        if (r == 0) return 0;
        if (r == -1) r = N;
        r += SIZE_R;
        Monoid sum = IDENTITY;
        do {
            --r;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(F(dat[r], sum))) {
                while (r < SIZE_R) {
                    r = r * 2 + 1;
                    if (f(F(dat[r], sum))) {
                        sum = F(dat[r], sum);
                        --r;
                    }
                }
                return r + 1 - SIZE_R;
            }
            sum = F(dat[r], sum);
        } while ((r & -r) != r);
        return 0;
    }
    
    /* debug */
    void print() {
        for (int i = 0; i < N; ++i) {
            cout << (*this)[i];
            if (i != N-1) cout << ",";
        }
        cout << endl;
    }
};

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;    // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> rank;  // rank[sa[i]] = i
    vector<int> lcp;   // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)

    // getter
    int& operator [] (int i) {
        return sa[i];
    }
    vector<int> get_sa() { return sa; }
    vector<int> get_rank() { return rank; }
    vector<int> get_lcp() { return lcp; }

    // constructor
    SuffixArray() {}
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        vector<int> s;
        for (int i = 0; i < (int)str.size(); ++i) {
            s.push_back(str[i] + 1);
        }
        s.push_back(0);
        sa = sa_is(s);
        calcLCP(s);
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // find min id that str.substr(sa[id]) >= T
    int lower_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], string::npos, T) < 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // find min id that str.substr(sa[id], T.size()) > T
    int upper_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], T.size(), T) <= 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // search
    bool is_contain(const Str& T) {
        int lb = lower_bound(T);
        if (lb >= sa.size()) return false;
        return str.compare(sa[lb], T.size(), T) == 0;
    }

    // find lcp
    void calcLCP(const vector<int> &s) {
        int N = (int)s.size();
        rank.assign(N, 0), lcp.assign(N, 0);
        for (int i = 0; i < N; ++i) rank[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rank[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rank[i] - 1] = h;
        }
    }

    // debug
    void dump() {
        cout << str << endl;
        for (int i = 0; i < sa.size(); ++i) {
            cout << i << ": " << sa[i] << ", " << str.substr(sa[i]) << endl;
        }
    }
};

int main() {
    // 入力
    int N, Q;
    cin >> N;
    vector<string> S(N);
    string T;  // S を '$' を挟みながら連結したもの
    vector<int> lp(N), rp(N);  // T における S[i] 部分の始点と終点
    for (int i = 0; i < N; ++i) {
        cin >> S[i];
        if (i) T += "$";
        lp[i] = T.size(), T += S[i], rp[i] = T.size();
    }
    cin >> Q;
    vector<long long> x(Q);
    for (int i = 0; i < Q; ++i) cin >> x[i];

    // Suffix Array の構築
    SuffixArray<string> suf(T);
    vector<int> sa = suf.get_sa();
    vector<int> rank = suf.get_rank();
    vector<int> lcp = suf.get_lcp();

    // $ の影響を除くための処理
    // rem_len[i] := sa[i] を表す文字列の先頭から $ までの残り文字数
    // sid[i] := sa[i] を表す文字列に対応する S の id
    // sstart[i] := sa[i] を表す文字列が S[sid] の何文字から開始か
    vector<int> rem_len(sa.size()), sid(sa.size()), sstart(sa.size());
    for (int j = 0; j < N; ++j) {
        for (int i = lp[j]; i < rp[j]; ++i) {
            rem_len[rank[i]] = rp[j] - i;
            sid[rank[i]] = j;
            sstart[rank[i]] = i - lp[j];
        }
    }
    for (int i = 0; i < lcp.size(); ++i) {
        chmin(lcp[i], rem_len[i]);
        chmin(lcp[i], rem_len[i+1]);
    }

    // lcp をセグ木に乗せる
    SegTree<pll> seg(lcp.size()
                    , [&](pll a, pll b) { return min(a, b); }
                    , pll(1LL<<29, 1LL<<29));
    for (int i = 0; i < lcp.size(); ++i) seg.set(i, pll(lcp[i], i));
    seg.build();

    // suffix tree 上を dfs
    // suffix tree の区間 [left, right) の深さ depth のところを探索
    long long num = 0;
    int query = 0;
    auto dfs = [&](auto self, int left, int right, int depth) -> void {
        // 終端条件
        if (right - left <= 0) return;

        // 区間 [left, right) の共通接尾辞を求める
        auto [nex_depth, mid] = seg.get(left, right-1);
        if (right - left == 1) {
            mid = left;
            nex_depth = rem_len[left];
        }
        long long width = right - left;
        long long add = width * (nex_depth - depth);

        // num 〜 num + add の範囲内にある x を処理していく
        while (query < Q && x[query] <= num + add) {
            long long len = depth + (x[query] - num + width - 1) / width;
            cout << sid[mid]+1 << " " << sstart[mid]+1 << " "
                 << sstart[mid]+len << endl;
            ++query;
        }
        num += add;

        if (right - left == 1 && nex_depth == depth) return;
        
        // 再帰的に処理
        self(self, left, mid+1, nex_depth);
        self(self, mid+1, right, nex_depth);
    };
    dfs(dfs, 0, sa.size(), 0);
}

Cartesian 木で高速化

Suffix Tree は、高さ配列 lcp の Cartesian 木でもある。Cartesian 木の構築は線形時間でできることから、全体の計算量を  O(D + Q) に高速化できる。

#include <bits/stdc++.h>
using namespace std;
using pll = pair<long long, long long>;
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

// Cartesian Tree
template<class T> struct CartesianTree {
    int root;  // root
    vector<int> par, left, right;

    CartesianTree() {}
    CartesianTree(const vector<T>& v) : root(0)
    , par(v.size(), -1), left(v.size(), -1), right(v.size(), -1) {
        vector<int> st(v.size(), 0);
        int top = 0;
        for (int i = 1; i < v.size(); ++i) {
            if (v[st[top]] > v[i]) {
                while (top >= 1 && v[st[top - 1]] > v[i]) --top;
                par[left[i] = st[top]] = i;
                if (top == 0) root = i;
                else right[par[i] = st[top - 1]] = i;
                st[top] = i;
            } else {
                right[par[i] = st[top]] = i;
                st[++top] = i;
            }
        }
    }
};               

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;    // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> rank;  // rank[sa[i]] = i
    vector<int> lcp;   // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)

    // getter
    int& operator [] (int i) {
        return sa[i];
    }
    vector<int> get_sa() { return sa; }
    vector<int> get_rank() { return rank; }
    vector<int> get_lcp() { return lcp; }

    // constructor
    SuffixArray() {}
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        vector<int> s;
        for (int i = 0; i < (int)str.size(); ++i) {
            s.push_back(str[i] + 1);
        }
        s.push_back(0);
        sa = sa_is(s);
        calcLCP(s);
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // find min id that str.substr(sa[id]) >= T
    int lower_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], string::npos, T) < 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // find min id that str.substr(sa[id], T.size()) > T
    int upper_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], T.size(), T) <= 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // search
    bool is_contain(const Str& T) {
        int lb = lower_bound(T);
        if (lb >= sa.size()) return false;
        return str.compare(sa[lb], T.size(), T) == 0;
    }

    // find lcp
    void calcLCP(const vector<int> &s) {
        int N = (int)s.size();
        rank.assign(N, 0), lcp.assign(N - 1, 0);
        for (int i = 0; i < N; ++i) rank[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rank[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rank[i] - 1] = h;
        }
    }

    // debug
    void dump() {
        cout << str << endl;
        for (int i = 0; i < sa.size(); ++i) {
            cout << i << ": " << sa[i] << ", " << str.substr(sa[i]) << endl;
        }
    }
};

void ABC280Ex() {
    // 入力
    int N, Q;
    cin >> N;
    vector<string> S(N);
    string T;  // S を '$' を挟みながら連結したもの
    vector<int> lp(N), rp(N);  // T における S[i] 部分の始点と終点
    for (int i = 0; i < N; ++i) {
        cin >> S[i];
        if (i) T += "$";
        lp[i] = T.size(), T += S[i], rp[i] = T.size();
    }
    cin >> Q;
    vector<long long> x(Q);
    for (int i = 0; i < Q; ++i) cin >> x[i];

    // Suffix Array の構築
    SuffixArray<string> suf(T);
    vector<int> sa = suf.get_sa();
    vector<int> rank = suf.get_rank();
    vector<int> lcp = suf.get_lcp();

    // $ の影響を除くための処理
    // rem_len[i] := sa[i] を表す文字列の先頭から $ までの残り文字数
    // sid[i] := sa[i] を表す文字列に対応する S の id
    // sstart[i] := sa[i] を表す文字列が S[sid] の何文字から開始か
    vector<int> rem_len(sa.size()), sid(sa.size()), sstart(sa.size());
    for (int j = 0; j < N; ++j) {
        for (int i = lp[j]; i < rp[j]; ++i) {
            rem_len[rank[i]] = rp[j] - i;
            sid[rank[i]] = j;
            sstart[rank[i]] = i - lp[j];
        }
    }
    for (int i = 0; i < lcp.size(); ++i) {
        chmin(lcp[i], rem_len[i]);
        chmin(lcp[i], rem_len[i+1]);
    }

    // suffix tree: lcp の Cartesian 木
    CartesianTree<int> ct(lcp);

    // suffix tree 上を dfs
    // suffix tree の区間 [left, right), 最小値 mid の深さ depth のところを探索
    long long num = 0;
    int query = 0;
    auto dfs = [&](auto self, int left, int right, int mid, int depth) -> void {
        // 終端条件
        if (right - left <= 0) return;

        // mid の処理
        if (right - left == 1) mid = left;

        // num 〜 num + add の範囲内にある x を処理していく
        long long width = right - left;
        long long nex_depth = (right - left > 1 ? lcp[mid] : rem_len[left]);
        long long add = width * (nex_depth - depth);
        while (query < Q && x[query] <= num + add) {
            long long len = depth + (x[query] - num + width - 1) / width;
            cout << sid[mid]+1 << " " << sstart[mid]+1 << " "
                 << sstart[mid]+len << endl;
            ++query;
        }
        num += add;
        
        // 再帰的に処理
        if (right - left > 1) {
            self(self, left, mid+1, ct.left[mid], nex_depth);
            self(self, mid+1, right, ct.right[mid], nex_depth);
        }
    };
    dfs(dfs, 0, sa.size(), ct.root, 0);
}

int main() {
    //YosupoJudge();
    ABC280Ex();
}

AtCoder ABC 282 F - Union of Two Sets (青色, 600 点)

コンテスト中、Sparse Table だと思わずに解いていた。コンテスト後に TL で Sparse Table だと見て、「確かに!」と思った。

問題概要

インタラクティブ問題である。次のフェーズ I とフェーズ II に分かれる。

フェーズ I

ジャッジから整数  N が与えられるので、それに応じて、あなたは 1 以上 50000 以下の整数  M を出力する。

さらにあなたは、各  i=1,2,\dots,M について  1 \le l_{i} ​\le r_{i}​ \le N を満たす、 M 個の整数の組  (l_{1}, r_{1}), \dots, (l_{M}, r_{M}) を出力する。

フェーズ II

ジャッジから  Q 回のクエリが与えられる。各クエリでは 2 つの整数  L, R が与えられる。

それに対して、あなたは 1 以上  M 以下の 2 つの整数  a,b を、次の条件を満たすように出力する ( a=b でもよい)。

「集合 { l_{a}​,l_{a}​+1,\dots,r_{a}} と集合 { l_{b}​,l_{b}​+1,\dots,r_{b}} の和集合が、集合 { L, L+1, \dots, R} と一致する。

制約

  •  1 \le N \le 4000
  •  1 \le Q \le 10^{5}
  •  1 \le L \le R \le N

考えたこと

ここでは、Sparse Table をまったく知らなくてもアドホックに解ける思考過程を示す。

まず、この問題はインタラクティブな見た目をしているが、実質的には構築問題だといえる。つまり、整数  N に応じて、次の条件を満たすような区間の集合を求めよ、ということだ。


  • 区間の個数  M は 50000 以下
  • 任意の  1 \le L \le R \le N に対して、 M 個の区間から上手に 2 つを選んで和集合をとると区間  \lbrack L, R \rbrack に一致する

もし  M の大きさに上限がないならば、 \frac{N(N+1)}{2} 個の区間をすべて出力すれば十分だ。実際には、許される区間の個数は、オーダー的には  O(N \log N) 程度だ。なお、求める区間の集合を「解集合」と呼ぶことにする。

順に考えていく

まずは長さが 1 の区間をカバーしていくことを考える。つまり、 L = R という場合だ。こればかりは、すべて解集合に含める必要があるだろう。つまり、

 \lbrack 1, 1 \rbrack, \lbrack 2, 2 \rbrack, \dots, \lbrack N, N \rbrack

をすべて解集合に含めることにする。

これらの解集合から 2 つの区間を選んで和集合をとることによって、 \lbrack 1, 2 \rbrack, \lbrack 2, 3 \rbrack, \lbrack 3, 4 \rbrack \dots といった、長さ 2 の区間もすべて実現できる。そこで今度は、長さが 3 の区間を解集合に含めていくことにする。つまり、

 \lbrack 1, 3 \rbrack, \lbrack 2, 4 \rbrack, \lbrack 3, 5 \rbrack \dots, \lbrack N-2, N \rbrack

を解集合に含めていくことにする。このとき、長さ 6 以下の区間はすべて実現できることになる。たとえば

区間  \lbrack 19, 22 \rbrack は、区間  \lbrack 19, 21 \rbrack区間  \lbrack 20, 22 \rbrack の和集合である

といった具合だ。このノリで、

と続けていくことで、すべての長さの区間が実現できることになる。

区間の長さが指数的に増加していることから、最終的な解集合における区間の個数は  O(N \log N) のオーダーの個数になる。よって大丈夫。

なお、以上の考察は、Sparse Table をまったく知らなくても十分アドホックに導出できると思う。

余談:Sparse Table

Sparse Table を踏まえると、区間の長さの系列を  1, 3, 7, 15, 31, \dots とするよりは、 1, 2, 4, 8, 16, 32, \dots とする方が自然だ。それでももちろん正解になる。

Sparse Table の内容については、公式解説に言及がある。

コード

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;

int main() {
    // フェーズ I
    int N;
    cin >> N;

    vector<pint> res;  // 答え
    map<pint, int> ma;  // 区間 -> 区間番号
    int len = 1;
    while (len <= N) {
        for (int i = 1; i + len - 1 <= N; ++i) {
            res.emplace_back(i, i + len - 1);
            ma[pint(i, i + len - 1)] = res.size();
        }
        len = len * 2 + 1;
    }
    cout << res.size() << endl;
    for (int i = 0; i < res.size(); ++i)
        cout << res[i].first << " " << res[i].second << endl;

    // フェーズ II
    int Q;
    cin >> Q;
    while (Q--) {
        int u, v;
        cin >> u >> v;
        len = 1;
        while (len <= v - u + 1) len = len * 2 + 1;
        len /= 2;
        cout << ma[pint(u, u + len - 1)] << " " << ma[pint(v - len + 1, v)] << endl;
    }
}