文字列検索に関するライブラリが充実していれば怖いものがない。でも文字列のことを知らなくても実は DP でも解ける!!!
- Suffix Array
- Z-algorithm (editorial 解)
- ロリハ + 二分探索
- 「ロリハ + 二分探索」の高速化 (editorial のラスト 3 行で言及された別解)
- DP
の五通りの方法でやってみる
問題へのリンク
問題概要
長さ の文字列 があたえれる。
の連続する部分文字列として、重ならずに 2 回以上現れるもののうち、最長のものの長さを答えてください。
制約
解法 1:Suffix Array の LCP 配列
まず、蟻本の P.340 に書いてある方法。
ここで文字列 S の i 文字目から先を取り出した部分文字列を S[ i : ] と書くことにする。
さて、Suffix Array で何ができるのかだけ書くと
- で Suffix Array を構成する ( の方法もある)
- それによって求められた LCP 配列と呼ばれるものを Sparse Table に載せる ( かかる)
- 以上の前処理を行っておくと、S[ i : ] と S[ j : ] とが、先頭から何文字まで共通しているか (LCP とよぶ) を、 で判定することができる
ということになる (ここまですべて蟻本参照、ただし蟻本では Sparse Table のところをセグメントツリーでやっている)。この方法を使えば
- 各 i < j に対して、
- LCP の値 lcp を求めて
- そのままだと部分文字列がかぶることもある ("ababa" で i = 0, j = 2 のとき LCP は "aba" で長さ 3) ので、lcp が j - i と比べて大きかったら lcp = j - i とする
というのを各 i < j で全探索して最大値を求めれば OK。計算量は 。
以下のコードで、main 関数はとても短くて「ライブラリがあれば貼るだけ」を強く実感できる!!!
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
template<class MeetSemiLattice> struct SparseTable {
vector<vector<MeetSemiLattice> > dat;
vector<int> height;
SparseTable() { }
SparseTable(const vector<MeetSemiLattice> &vec) { init(vec); }
void init(const vector<MeetSemiLattice> &vec) {
int n = (int)vec.size(), h = 0;
while ((1<<h) < n) ++h;
dat.assign(h, vector<MeetSemiLattice>(1<<h));
height.assign(n+1, 0);
for (int i = 2; i <= n; i++) height[i] = height[i>>1]+1;
for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
for (int i = 1; i < h; ++i)
for (int j = 0; j < n; ++j)
dat[i][j] = min(dat[i-1][j], dat[i-1][min(j+(1<<(i-1)),n-1)]);
}
MeetSemiLattice get(int a, int b) {
return min(dat[height[b-a]][a], dat[height[b-a]][b-(1<<height[b-a])]);
}
};
struct SuffixArray {
string str;
vector<int> sa;
vector<int> lcp;
inline int& operator [] (int i) {return sa[i];}
SuffixArray(const string& str_) : str(str_) { buildSA(); calcLCP(); }
void init(const string& str_) { str = str_; buildSA(); calcLCP(); }
vector<int> rank_sa, tmp_rank_sa;
struct CompareSA {
int n, k;
const vector<int> &rank;
CompareSA(int n, int k, const vector<int> &rank_sa) : n(n), k(k), rank(rank_sa) {}
bool operator()(int i, int j) {
if (rank[i] != rank[j]) return (rank[i] < rank[j]);
else {
int rank_ik = (i + k <= n ? rank[i + k] : -1);
int rank_jk = (j + k <= n ? rank[j + k] : -1);
return (rank_ik < rank_jk);
}
}
};
void buildSA() {
int n = (int)str.size();
sa.resize(n+1), lcp.resize(n+1), rank_sa.resize(n+1), tmp_rank_sa.resize(n+1);
for (int i = 0; i < n; ++i) sa[i] = i, rank_sa[i] = (int)str[i];
sa[n] = n, rank_sa[n] = -1;
for (int k = 1; k <= n; k *= 2) {
CompareSA csa(n, k, rank_sa);
sort(sa.begin(), sa.end(), csa);
tmp_rank_sa[sa[0]] = 0;
for (int i = 1; i <= n; ++i) {
tmp_rank_sa[sa[i]] = tmp_rank_sa[sa[i - 1]];
if (csa(sa[i - 1], sa[i])) ++tmp_rank_sa[sa[i]];
}
for (int i = 0; i <= n; ++i) rank_sa[i] = tmp_rank_sa[i];
}
}
vector<int> rsa;
SparseTable<int> st;
void calcLCP() {
int n = (int)str.size();
rsa.resize(n+1);
for (int i = 0; i <= n; ++i) rsa[sa[i]] = i;
lcp.resize(n+1);
lcp[0] = 0;
int cur = 0;
for (int i = 0; i < n; ++i) {
int pi = sa[rsa[i] - 1];
if (cur > 0) --cur;
for (; pi + cur < n && i + cur < n; ++cur) {
if (str[pi + cur] != str[i + cur]) break;
}
lcp[rsa[i] - 1] = cur;
}
st.init(lcp);
}
int getLCP(int a, int b) {
return st.get(min(rsa[a], rsa[b]), max(rsa[a], rsa[b]));
}
};
int main() {
int N;
string S;
cin >> N >> S;
SuffixArray SA(S);
int res = 0;
for (int i = 0; i < N; ++i) {
for (int j = i+1; j < N; ++j) {
int lcp = SA.getLCP(i, j);
lcp = min(lcp, j-i);
res = max(res, lcp);
}
}
cout << res << endl;
}
解法 2:Z-algorithm
ほとんど同じ思想の解法で、Z-algorithm は
- の前処理で配列 lcp を求める
- lcp[ i ] := S 自身と S[ i : ] とが、先頭から最大で何文字一致しているかを で求めることができる
というもの。Suffix Array の LCP 配列は LCP(i, j) を求めることができたのに対して、Z-algorithm では LCP(0, i) の形しか求めることができない。しかしその対策は簡単で
- 各 i に対して の i 文字目以降の部分文字列を T とする
- T について Z-algorithm を適用する
という風にすれば OK。計算量はやはり 。
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
vector<int> Zalgo(const string &S) {
int N = (int)S.size();
vector<int> res(N);
res[0] = N;
int i = 1, j = 0;
while (i < N) {
while (i+j < N && S[j] == S[i+j]) ++j;
res[i] = j;
if (j == 0) {++i; continue;}
int k = 1;
while (i+k < N && k+res[k] < j) res[i+k] = res[k], ++k;
i += k, j -= k;
}
return res;
}
int main() {
int N;
string S;
cin >> N >> S;
int res = 0;
for (int i = 0; i < N; ++i) {
string T = S.substr(i);
auto lcp = Zalgo(T);
for (int j = 0; j < T.size(); ++j) {
int l = min(lcp[j], j);
res = max(res, l);
}
}
cout << res << endl;
}
解法 3:ロリハ + 二分探索
実は LCP はローリングハッシュ + 二分探索でも求められるという話は結構有名だったりする。ローリングハッシュは
- の前処理を行っておくことで
- 文字列 S の区間 [i, j) の部分文字列に関するハッシュ値を で返す
というもの。発想は累積和とものすごく近い!!!
さて、これができれば
- 各 i < j に対して
- 「i 文字目から始めて x 文字分とったもの」と「j 文字目から始めて x 文字分とったもの」のハッシュ値が一致するような最大の x を求める (by 二分探索)
という方法によって、 で LCP(i, j) を求めることができる。計算量は となる。このままではかなり定数倍が厳しい。以下のコードは 1735ms だった。
#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
struct RollingHash {
static const int base1 = 1007, base2 = 2009;
static const int mod1 = 1000000007, mod2 = 1000000009;
vector<long long> hash1, hash2, power1, power2;
RollingHash(const string &S) {
int n = (int)S.size();
hash1.assign(n+1, 0);
hash2.assign(n+1, 0);
power1.assign(n+1, 1);
power2.assign(n+1, 1);
for (int i = 0; i < n; ++i) {
hash1[i+1] = (hash1[i] * base1 + S[i]) % mod1;
hash2[i+1] = (hash2[i] * base2 + S[i]) % mod2;
power1[i+1] = (power1[i] * base1) % mod1;
power2[i+1] = (power2[i] * base2) % mod2;
}
}
inline pair<long long, long long> get(int l, int r) const {
long long res1 = hash1[r] - hash1[l] * power1[r-l] % mod1;
if (res1 < 0) res1 += mod1;
long long res2 = hash2[r] - hash2[l] * power2[r-l] % mod2;
if (res2 < 0) res2 += mod2;
return {res1, res2};
}
inline int getLCP(int a, int b) const {
int len = min((int)hash1.size()-a, (int)hash1.size()-b);
int low = 0, high = len;
while (high - low > 1) {
int mid = (low + high) >> 1;
if (get(a, a+mid) != get(b, b+mid)) high = mid;
else low = mid;
}
return low;
}
};
int main() {
int N;
string S;
cin >> N >> S;
RollingHash rh(S);
int res = 0;
for (int i = 0; i < N; ++i) {
for (int j = i+1; j < N; ++j) {
int lcp = rh.getLCP(i, j);
lcp = min(lcp, j-i);
res = max(res, lcp);
}
}
cout << res << endl;
}
解法 4:ロリハにぶたん解の高速化
ロリハ二分探索解は単純にやると定数倍がかなり辛いので、ちょっとした高速化ができる。
- 各 (i, j) に対して二分探索して LCP を求める
という風にする代わりに
- 長さ m を予め固定して、S.substr(i, m) == S.substr(j, m) となるような (i, j) が存在するかどうかを判定して二分探索
という感じにする。各 i に対して区間 [i, i + m) のハッシュ値を求めて、それらに重複があるかどうか (ただしそれらの index が m 以上離れている必要がある) を判定するだけになるので、かなり早くなる。 程度。この工夫で、1735ms -> 13 ms になった。
#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <algorithm>
using namespace std;
struct RollingHash {
static const int base1 = 1007, base2 = 2009;
static const int mod1 = 1000000007, mod2 = 1000000009;
vector<long long> hash1, hash2, power1, power2;
RollingHash(const string &S) {
int n = (int)S.size();
hash1.assign(n+1, 0);
hash2.assign(n+1, 0);
power1.assign(n+1, 1);
power2.assign(n+1, 1);
for (int i = 0; i < n; ++i) {
hash1[i+1] = (hash1[i] * base1 + S[i]) % mod1;
hash2[i+1] = (hash2[i] * base2 + S[i]) % mod2;
power1[i+1] = (power1[i] * base1) % mod1;
power2[i+1] = (power2[i] * base2) % mod2;
}
}
inline pair<long long, long long> get(int l, int r) const {
long long res1 = hash1[r] - hash1[l] * power1[r-l] % mod1;
if (res1 < 0) res1 += mod1;
long long res2 = hash2[r] - hash2[l] * power2[r-l] % mod2;
if (res2 < 0) res2 += mod2;
return {res1, res2};
}
};
int main() {
int N;
string S;
cin >> N >> S;
RollingHash rh(S);
auto check = [&](int d) -> bool {
map<pair<long long, long long>, int> ma;
for (int i = 0; i + d <= N; ++i) {
auto p = rh.get(i, i+d);
if (ma.count(p)) {
if (i - ma[p] >= d) return true;
}
else ma[p] = i;
}
return false;
};
int left = 0, right = N/2 + 1;
while (right - left > 1) {
int mid = (left + right) >> 1;
if (check(mid)) left = mid;
else right = mid;
}
cout << left << endl;
}
解法 5:DP
最後に、実は高度な文字列検索アルゴリズムを持っていなくても解けると。
- dp[ i ][ j ] := i 文字目からと j 文字目からとで最長の長さ (この時点では文字列がかぶっても OK とする)
とすると
- S[ i ] != S[ j ] なら dp[ i ][ j ] = 0
- そうでないなら dp[ i ][ j ] = dp[ i + 1 ][ j + 1 ] + 1
という感じ。更新順に注意する。最後は min(dp[ i ][ j ], j - i) の最大値を出力する。
#include <iostream>
#include <vector>
#include <string>
using namespace std;
void chmax(int &a, int b) { if (a < b) a = b; }
int main() {
int N; string S; cin >> N >> S;
int res = 0;
vector<vector<int> > dp(N+1, vector<int>(N+1, 0));
for (int i = N-1; i >= 0; --i) {
for (int j = N-1; j > i; --j) {
if (S[i] == S[j]) chmax(dp[i][j], dp[i+1][j+1] + 1);
chmax(res, min(dp[i][j], j-i));
}
}
cout << res << endl;
}