けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 282 D - Make Bipartite 2 (緑色, 400 点)

一般にグラフの問題を解くときは「連結成分ごとに解けば良いのではないか」と考えるのが有効なことがある!

その意識がしっかりしていれば、「グラフが非連結の場合に気づかなかった」という罠を回避できる!!

問題概要

頂点数  N、辺数  M の単純な無向グラフが与えられます。頂点番号を  1, 2, \dots, N とします。

このグラフにおいて、以下の条件を満たす頂点の組  (u, v) ( 1 \le u \lt v \le N) の個数を答えてください。

  • 頂点  u, v を結ぶ辺が存在しない
  • 頂点  u, v を結ぶ辺を追加したグラフは二部グラフである

制約

  •  2 \le N \le 2 \times 10^{5}
  •  0 \le M \le 2 \times 10^{5}

考えたこと

二部グラフとは、下図のように、

  • 白色の頂点同士が隣接しない
  • 黒色の頂点同士が隣接しない

という条件を満たすように、各頂点を白黒に塗り分けられるようなグラフ。

二部グラフ判定はたとえば、けんちょん本などに書いた。他にも次の記事などにも書いた。

qiita.com

平たく言えば、二部グラフの判定方法は次のような感じになる。


  • グラフの連結成分のうち、まだ考えていない連結成分について、1 つの頂点を選んで「白色」に塗る (黒色でもよい)
  • その頂点とつながる頂点に対しては、次々と連鎖的に色が決まっていく
  • その決まった色が整合性を取れない場合は二部グラフではない
  • 以上の作業を、連結成分がなくなるまで繰り返す

 

連結成分ごとに考える

冒頭に書いた通り、「グラフの問題を連結成分ごとに考える問題へと帰着する」という方針で考える。

まず、連結成分の中に二部グラフでないグラフがあった時点で、グラフ全体も二部グラフではない。以降は、グラフが二部グラフであるものとして考えることにする。

さて、この問題で追加する辺を次の 2 つの場合に分けて考える。

  1. 同じ連結成分同士の 2 頂点を結ぶ辺
  2. 異なる連結成分間の 2 頂点を結ぶ辺

1. 同じ連結成分同士の 2 頂点を結ぶ辺

同じ連結成分内部の辺については次のように考えられる。二部グラフ判定したときの

  • 白色の頂点の個数を  a
  • 黒色の頂点の個数を  b
  • すでにある辺の本数を  m

としたときに、 ab - m 本の頂点組には辺を追加できる。なお、後述するように、実際には「連結成分ごとに  m を求める」という作業は必要ない。

ここで気になるのは、連結成分の色の塗り方によっては、 a, b の値が変わってしまうのではないかということだ。

しかし実はその心配はない。二部グラフ判定においては「1 つの頂点の色を決めると残りの頂点の色も自動的に決まる」ということを利用する。

よって、「白色: b 個」「黒色  a 個」というように  a, b が入れ替わることはあっても、値が変わることはないのだ。

2. 異なる連結成分間の 2 頂点を結ぶ辺

実は、二部グラフと二部グラフの間をどう結んでも、二部グラフのままなのだ。

たとえ同じ色同士を結んでしまったとしても、片方の連結成分の色を反転させればよい。

 

まとめ

以上の考察をまとめよう。結べる頂点組の個数を数えるよりも、結べない頂点組の個数を数える方が楽そうなので、ここではそうする (どっちでも正解できる)。

結んではいけない辺は「同じ連結成分内の、同じ色の頂点同士」ということがわかる。

具体的には、同じ連結成分内の白色頂点の個数を  a、黒色頂点の個数を  b としたとき、結んではいけない辺の本数は

 \displaystyle \frac{a(a-1)}{2} + \frac{b(b-1)}{2}

と表せる。これを各連結成分について合算すればよい。

全体として、計算量は  O(N + M) となる。

コード

ここでは BFS で実装する。

Python3

import queue

# N  * (N - 1) // 2
def com(N):
    return N * (N - 1) // 2

# 色
WHITE, BLACK = 0, 1

# 入力
N, M = map(int, input().split())
edges = [list(map(int, input().split())) for _ in range(M)]
G = [[] for _ in range(N)]
for e in edges:
    u, v = e[0]-1, e[1]-1
    G[u].append(v)
    G[v].append(u)

# 各変数
num_ng_edges = 0
is_bipartite = True
color = [-1] * N

# 各連結成分を走査する
for s in range(N):
    if color[s] != -1: continue

    # 白色頂点、黒色頂点の個数
    white_num, black_num = 0, 0

    # 頂点 s を始点とする BFS
    que = queue.Queue()
    que.put(s)
    color[s] = WHITE
    while not que.empty():
        v = que.get()

        # 頂点の個数をカウント
        if color[v] == WHITE: white_num += 1
        else: black_num += 1

        # 頂点 v の隣接頂点について
        for v2 in G[v]:
            if color[v2] != -1:
                # すでに色が塗られているとき
                if color[v2] == color[v]:
                    # 隣接頂点が同色はダメ
                    is_bipartite = False
            else:
                # 新しい頂点を処理
                color[v2] = 1 - color[v]
                que.put(v2)

    # ng 辺の本数
    num_ng_edges += com(white_num) + com(black_num)

# 答え
print(com(N) - M - num_ng_edges if is_bipartite else 0)

C++

#include <bits/stdc++.h>
using namespace std;
using Graph = vector<vector<int>>;  // グラフのデータ構造

// 色
const int WHITE = 0;
const int BLACK = 1;

int main() {
    // 入力
    long long N, M;
    cin >> N >> M;
    Graph G(N);
    for (int i = 0; i < M; ++i) {
        int a, b;
        cin >> a >> b;
        --a, --b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    long long num_ng_edges = 0;
    bool is_bipartite = true;  // 二部グラフかどうか
    vector<int> color(N, -1);  // 各頂点の色

    for (int s = 0; s < N; ++s) {
        if (color[s] != -1) continue;

        // 白色頂点、黒色頂点の個数
        long long white_num = 0, black_num = 0;

        // 頂点 s を始点とする BFS
        queue<int> que;
        que.push(s);
        color[s] = WHITE;  // 頂点 s を白色に塗る
        while (!que.empty()) {
            int v = que.front();
            que.pop();

            // 頂点の個数をカウント
            if (color[v] == WHITE) 
                ++white_num;
            else
                ++black_num;

            // 頂点 v の隣接頂点について
            for (auto v2 : G[v]) {
                if (color[v2] != -1) {
                    // すでに色が塗られているとき
                    if (color[v2] == color[v]) {
                        // 隣接頂点が同色はダメ
                        is_bipartite = false;
                    }
                } else {
                    // 新しい頂点を処理
                    color[v2] = 1 - color[v];
                    que.push(v2);
                }
            }
        }

        // ng 辺の本数
        num_ng_edges += white_num * (white_num - 1) / 2
                        + black_num * (black_num - 1) / 2;
    }

    // 答え
    if (is_bipartite)
        cout << N * (N - 1) / 2 - M - num_ng_edges << endl;
    else
        cout << 0 << endl;
}

AtCoder ABC 282 C - String Delimiter (灰色, 300 点)

フラグの考え方で解くのが一番分かりやすいと思った

問題概要

英小文字と、文字 ," からなる長さ  N の文字列  S が与えられます。

 S に含まれる文字 " の個数は偶数であることが保証されています。

 S に含まれる " の個数を  2K 個とすると、各  i=1,2,\dots,K について  2i−1 番目の " から  2i 番目の " までの文字のことを「括られた文字」と呼びます。

あなたの仕事は、文字列  S に含まれる , のうち、括られた文字でないもの を . で置き換えて得られる文字列を答えることです。

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

色んな考え方がありそう。ここでは、for 文を用いて処理していく方法を考える。

for 文中の各文字に対して「今括られているのかどうか」を表すフラグ変数を用意する


is_inclusion ← 今括られているかどうかを表す値 (括られている:1、括られていない:0)


そして、次のような for 文で処理していけばよい。

int is_inclusion = 0;
for (int i = 0; i < N; ++i) {
    if (S[i] == '"') {
        // 括られているかどうかが変化する
        is_inclusion = 1 - is_inclusion  
    } else if (S[i] == ',' && !is_inclusion) {
        S[i] = '.';
    }
}

計算量は  O(N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    string S;
    cin >> N >> S;

    int is_inclusion = 0;
    for (int i = 0; i < N; ++i) {
        if (S[i] == '"') {
            // 括られているかどうかが変化する
            is_inclusion = 1 - is_inclusion;
        } else if (S[i] == ',' && !is_inclusion) {
            S[i] = '.';
        }
    }
    cout << S << endl;
}

AtCoder ARC 050 C - LCM 111 (試験管黄色)

レピュニット数を題材とした問題。

問題概要

 1 A 個並べた数と、 1 B 個並べた数の最小公倍数を  M で割ったあまりを求めよ。

制約

  •  1 \le A, B \le 10^{18}
  •  1 \le M \le 10^{9}

考えたこと

 1 を並べた数をレピュニット数とよぶ。数学的に面白い性質が色々ある。

レピュニット数の最大公約数は、実は  1 \mathrm{GCD}(A, B) 個並べた数になる。このことを Euclid の互助法に基づいて検証する。

たとえば  111111 ( 1 6 個) と、 1111 ( 1 4 個) の最大公約数を求めたいとする。

 111111 - 1111 = 110000

であり、 10 11...1 は互いに素なので、

 \mathrm{GCD}(111111, 1111) = \mathrm{GCD}(11, 1111)

になると言える。一般に、 A \gt B のとき、

 \mathrm{GCD}(1 が A 個, 1 が B 個) = \mathrm{GCD}(1 が A - B 個, 1 が B 個)

が成立する。この  A, B の部分の変化が Euclid の互助法における変化に他ならない。よって、上に述べたことが成立する。

最小公倍数を求める

 1 A 個並べた数を  f(A) と表すことにすると、求めたい数は  G = \mathrm{GCD}(A, B) として、

 \displaystyle \frac{f(A)f(B)}{f(G)} \pmod m

ということになる。 m が素数とは限らないので、 f(G) で割るという演算は回避したい。そこで、

  •  f(A) = 1 + 10 + 10^{2} + \dots + 10^{A-1}
  •  \displaystyle \frac{f(B)}{f(G)} = 1 + 10^{G} + 10^{2G} + \dots + 10^{(\frac{B}{G}-1)G}

であることに着目してこれを求めることにする。

全体として、計算量は  O(\log A + \log B) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// a^n mod m
long long pow(long long a, long long n, long long m) {
    if (n == 0) return 1;
    long long t = pow(a, n/2, m);
    t = t * t % m;
    if (n & 1) t = t * a % m;
    return t;
}

// 1 + a + ... + a^{n-1} mod m
long long ser(long long a, long long n, long long m) {
    if (n == 0) return 0;
    long long res = ser(a * a % m, n/2, m);
    res = res * (a + 1) % m;
    if (n & 1) res = (res * a + 1) % m;
    return res;
}

int main() {
    long long A, B, M;
    cin >> A >> B >> M;

    long long G = gcd(A, B);
    long long fA = ser(10, A, M);
    long long fBdivfG = ser(pow(10, G, M), B/G, M);
    cout << fA * fBdivfG % M << endl;
}

CS Academy 073 DIV2 E - Strange Substring

文字列の連続した部分文字列を数え上げるのは Suffix Array の典型問題。それを少し応用した面白い問題!

問題概要

英小文字からなる 2 つの文字列  A, B が与えられる。次の条件を満たす文字列の個数を求めよ。

  •  A 中に連続した部分文字列として含まれる
  •  B 中に連続した部分文字列として含まれない

制約

  •  1 \le |A|, |B| \le 10^{5}

考えたこと

前提となる問題がこちら。

drken1215.hatenablog.com

さて、2 つの文字列にまたがる文字列検索について考える問題では、

 S =  A + "?" +  B

というように、 A, B を連結させた文字列について Suffix Array を求めるのは常套手段。

たとえば、 A = "abcab"、 B = "bcab" のとき、接尾辞配列は次のようになる。

0: 10, 
1: 5, ?bcab
2: 8, ab
3: 3, ab?bcab
4: 0, abcab?bcab
5: 9, b
6: 4, b?bcab
7: 6, bcab
8: 1, bcab?bcab
9: 7, cab
10: 2, cab?bcab

このうち、"?" を含む suffix たち (それ以外は文字列  B に関するものなのでスキップ) について、次のルールで集計していくことにする。


今考えている suffix の prefix として考えられる文字列のうち

  1. "?" を含むものは除外する
  2. Suffix Array のより後方にある suffix の prefix となり得るものは除外する
  3. Suffix Array のより前方にある suffix のうち、index が B 側にあるものの prefix となり得るものも除外する

具体的には、

  • 1 を考慮すると、個数は add = N - sa[i] と表せる
  • 2, 3 をそれぞれ考慮して、その最大値を sub として、add から sub を引くことにする
    • 2 を考慮すると、chmax(sub, lcp[i]) と表せる
    • 3 を考慮すると、Suffix Array において直前の B 由来の index を prev として、chmax(sub, lcp(sa[i], prev)) と表せる

そうして、add - sub の値の総和をとったものが答えとなる。

全体として、 O(N \log N) の計算量で求められる。

コード

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;
using pll = pair<long long, long long>;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

#define REP(i, n) for (long long i = 0; i < (long long)(n); ++i)
#define REP2(i, a, b) for (long long i = a; i < (long long)(b); ++i)
#define COUT(x) cout << #x << " = " << (x) << " (L" << __LINE__ << ")" << endl
template<class T1, class T2> ostream& operator << (ostream &s, pair<T1,T2> P)
{ return s << '<' << P.first << ", " << P.second << '>'; }
template<class T> ostream& operator << (ostream &s, vector<T> P)
{ for (int i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }
template<class T> ostream& operator << (ostream &s, deque<T> P)
{ for (int i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }
template<class T> ostream& operator << (ostream &s, vector<vector<T> > P)
{ for (int i = 0; i < P.size(); ++i) { s << endl << P[i]; } return s << endl; }
template<class T> ostream& operator << (ostream &s, set<T> P)
{ for(auto it : P) { s << "<" << it << "> "; } return s; }
template<class T1, class T2> ostream& operator << (ostream &s, map<T1,T2> P)
{ for(auto it : P) { s << "<" << it.first << "->" << it.second << "> "; } return s; }

// Sparse Table
template<class MeetSemiLattice> struct SparseTable {
    vector<vector<MeetSemiLattice> > dat;
    vector<int> height;
    
    SparseTable() { }
    SparseTable(const vector<MeetSemiLattice> &vec) { init(vec); }
    void init(const vector<MeetSemiLattice> &vec) {
        int n = (int)vec.size(), h = 0;
        while ((1<<h) < n) ++h;
        dat.assign(h, vector<MeetSemiLattice>(1<<h));
        height.assign(n+1, 0);
        for (int i = 2; i <= n; i++) height[i] = height[i>>1]+1;
        for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
        for (int i = 1; i < h; ++i)
            for (int j = 0; j < n; ++j)
                dat[i][j] = min(dat[i-1][j], dat[i-1][min(j+(1<<(i-1)),n-1)]);
    }
    
    MeetSemiLattice get(int a, int b) {
        return min(dat[height[b-a]][a], dat[height[b-a]][b-(1<<height[b-a])]);
    }
};

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;    // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> rank;  // rank[sa[i]] = i
    vector<int> lcp;   // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    SparseTable<int> st;  // use for calcultating lcp(i, j)

    // getter
    int& operator [] (int i) {
        return sa[i];
    }
    vector<int> get_sa() { return sa; }
    vector<int> get_rank() { return rank; }
    vector<int> get_lcp() { return lcp; }

    // constructor
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        vector<int> s;
        for (int i = 0; i < (int)str.size(); ++i) {
            s.push_back(str[i] + 1);
        }
        s.push_back(0);
        sa = sa_is(s);
        calcLCP(s);
        buildSparseTable();
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // find min id that str.substr(sa[id]) >= T
    int lower_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], string::npos, T) < 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // find min id that str.substr(sa[id], T.size()) > T
    int upper_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], T.size(), T) <= 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // search
    bool is_contain(const Str& T) {
        int lb = lower_bound(T);
        if (lb >= sa.size()) return false;
        return str.compare(sa[lb], T.size(), T) == 0;
    }

    // find lcp
    void calcLCP(const vector<int> &s) {
        int N = (int)s.size();
        rank.assign(N, 0), lcp.assign(N, 0);
        for (int i = 0; i < N; ++i) rank[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rank[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rank[i] - 1] = h;
        }
    }
    
    // build sparse table for calculating lcp
    void buildSparseTable() {
        st.init(lcp);
    }

    // calc lcp of str.sutstr(a) and str.substr(b)
    int getLCP(int a, int b) {
        return st.get(min(rank[a], rank[b]), max(rank[a], rank[b]));
    }
};

int main() {
    string A, B;
    cin >> A >> B;
    int N = A.size(), M = B.size();

    // Suffix Array の構築
    string S = A + "?" + B;
    SuffixArray<string> suf(S);
    vector<int> sa = suf.get_sa();
    vector<int> lcp = suf.get_lcp();
    
    // 集計
    long long res = 0;
    int prev = -1;
    for (int i = 0; i < sa.size(); ++i) {
        // B 側で始まる場合は、その index を記録しておいてスキップ
        if (sa[i] > N) {
            prev = sa[i];
            continue;
        }
        int add = N - sa[i];
        int sub = 0;
        if (i < lcp.size()) chmax(sub, lcp[i]);
        if (prev > N) chmax(sub, suf.getLCP(sa[i], prev));
        res += add - sub;
    }
    cout << res << endl;
}

square869120Contest #2 E - 部分文字列

Suffix Array の典型問題

問題概要

英小文字のみからなる長さ  N の文字列  S が与えられる。

 S の連続する部分文字列として登場しうる文字列をすべて考える。重複は除外する。これらの文字列の長さの総和を求めよ。

制約

  •  1 \le N \le 10^{5}

考えたこと

次の問題とほとんど同じ!

drken1215.hatenablog.com

このリンク先の問題は、部分文字列 (重複は除外) の個数を答える問題だった。今回は「個数」ではなく「長さの総和」を答える。

とは言え、ほとんど同じ考え方で解ける。

たとえば、文字列  S = poporinri を例にとる。この文字列の Suffix Array や高さ配列 LCP を求めると次のようになる。

  • 接尾辞配列:sa = [9, 8, 5, 6, 1, 3, 0, 2, 7, 4]
  • 高さ配列:lcp = [0, 1, 0, 0, 1, 0, 2, 0, 2]

具体的には次の通りである。

0: (9 文字目以降、空文字列)
1:i (8 文字目以降)
2:inri (5 文字目以降)
3:nri (6 文字目以降)
4:oporinri (1 文字目以降)
5:orinri (3 文字目以降)
6:poporinri (0 文字目以降)
7:porinri (2 文字目以降)
8:ri (7 文字目以降)
9:rinri (4 文字目以降)

この情報をもとにして、部分文字列をすべて走査する方法を考えよう。上記の suffix たちに対して、prefix をとってできる文字列の集合から、重複を除去したものが答えとなる。

重複を除去するためには、高さ配列 lcp が役に立つ。たとえば、sa[6] を表す文字列 "poporinri" と sa[7] を表す文字列 "porinri" は、lcp 値は 2 である (下表参照)。これは、先頭から 2 文字分が重複していることを意味する。こうした重複を除去していけばよい。

Suffix Array での順序 suffix 文字列 i 番目と i+1 番目の重複分の長さ lcp[i]
0 0
1 i 1
2 inri 0
3 nri 0
4 oporinri 1
5 orinri 0
6 poporinri 2
7 porinri 0
8 ri 2
9 rinri 0

重複を除去するために、「Suffix Array の順序で後ろの方の文字列の prefix にも登場する文字列は除外して数えていく」というようにする。つまり、各  i (=0, 1, \dots, N) に対して、

  • sa[i] を表す文字列 (長さを  m とする) の prefix となる文字列の長さの総和: \displaystyle \frac{m(m+1)}{2} から、 lcp[i] によって表される部分の文字列 (長さを  m とする) の prefix となる文字列の長さの総和: \displaystyle \frac{l(l+1)}{2}

を引いた値を求め、その総和をとればよい。

全体として  O(N) の計算量で求められる。

コード

#include <bits/stdc++.h>
using namespace std;

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;   // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> lcp;  // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    int& operator [] (int i) {
        return sa[i];
    }

    // constructor
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        int N = (int)str.size();
        vector<int> s;
        for (int i = 0; i < N; ++i) s.push_back(str[i] + 1);
        s.push_back(0);
        sa = sa_is(s);
        calcLCP(s);
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // prepair lcp
    vector<int> rsa;
    void calcLCP(const vector<int> &s) {
        int N = (int)s.size();
        rsa.assign(N, 0), lcp.assign(N, 0);
        for (int i = 0; i < N; ++i) rsa[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rsa[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rsa[i] - 1] = h;
        }
    }
};

int main() {
    string S;
    cin >> S;
    long long N = S.size();

    // Suffix Array の構築
    SuffixArray<string> sa(S);
    vector<int> lcp = sa.lcp;

    auto calc = [&](long long n) -> long long {
        return n * (n + 1) / 2;
    };

    // 集計していく
    long long res = 0;
    for (int i = 0; i <= N; ++i) {
        res += calc(N - sa[i]);
        if (i < N) res -= calc(lcp[i]);
    }

    // 答え
    cout << res << endl;
}