けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 2863 Separate String (JAG 模擬地区 2017 H) (500 点)

めちゃくちゃ面白かったし勉強になった!

問題概要

文字列  S が与えられる。それとは別に  N 個の文字列  T_{1}, \dots, T_{N} が与えられる。

文字列  S をいくつかの連続する区間に分割する方法であって、各区間をなす部分文字列が  T_{1}, \dots, T_{N} のいずれかに一致するようなものが何通りあるか、1000000007 で割ったあまりで求めよ。

制約

  •  1 \le |S|, N \le 10^{5}
  •  |T_{i}| の総和  \le 2 \times 10^{5}

考えたこと

まずこの手の区間分割処理は次のような DP で扱える。

  • dp[ i ] := S の最初の i 文字を、条件を満たすようにいくつかの区間に分割する方法の数

そして

dp[ j ] += dp[ i ] (区間 [i, j) が条件を満たすとき)

という風な遷移で実現できる。しかしナイーブに処理しては、DP 遷移だけで  O(N^{2})、それぞれの文字列パターンマッチングに  O(\sum_{i}|T_{i}|) を要するので全体で  O(N^{2}\sum_{i}|T_{i}|) の計算量となってしまう。なんとかして削減したい。

文字列検索処理

まずは文字列検索パートはいくらでも削減できる。

  • ローリングハッシュ
  • trie
  • Aho-Corasick

などなど、工夫できる道具がたくさんあるのだ。これらを用いることで、部分文字列が辞書に含まれるかどうかの判定は高速化できるものとして考える。

DP の遷移先が実は限られる

では DP の方はどうだろうか...この手の DP 高速化は

  • 累積和
  • いもす法
  • セグメント木
  • オンライン・オフライン変換 (+ in-place 化)
  • Convex Hull Trick
  • Monotone Minima

などの方法で高速化できることが多い。でも今回はこれらのいずれも上手く行かなそうだ。しかしここで注目したいことは、 M = \sum_{i}|T_{i}| として、

 |T_{1}|, \dots, |T_{N}| としてありうる値は、 O(\sqrt{M}) 通りしかない」

ということだ。一般に「総和が  M となる正の整数列に含まれる値の種類数は  O(\sqrt{M}) で抑えられることがいえる。仕組みは簡単で、

  • 値が  \sqrt{M} 以下のものは、その定義からそもそも  O(\sqrt{M}) 通りしかない
  • 値が  \sqrt{M} 以上のものは、全体の総和が  M になることから  O(\sqrt{M}) 個以下しかとれない

という具合だ。よって、

「DP の遷移先としては  O(\sqrt{M}) 通りしか発生しない」

ということがいえた。ある程度頑張って高速化すると通せる。

コード

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

struct RollingHash {
    static const int base1 = 1007;
    static const int mod1 = 1000000007;
    vector<long long> hash1, power1;

    // construct
    RollingHash(const string &S) {
        int n = (int)S.size();
        hash1.assign(n+1, 0);
        power1.assign(n+1, 1);
        for (int i = 0; i < n; ++i) {
            hash1[i+1] = (hash1[i] * base1 + S[i]) % mod1;
            power1[i+1] = (power1[i] * base1) % mod1;
        }
    }
    
    // get hash value of S[left:right]
    inline long long get(int l, int r) const {
        long long res1 = hash1[r] - hash1[l] * power1[r-l] % mod1;
        if (res1 < 0) res1 += mod1;
        return res1;
    }

    // get hash value of whole S
    inline long long get() const {
        return hash1.back();
    }

    // get lcp of S[a:] and S[b:]
    inline int getLCP(int a, int b) const {
        int len = min((int)hash1.size()-a, (int)hash1.size()-b);
        int low = 0, high = len;
        while (high - low > 1) {
            int mid = (low + high) >> 1;
            if (get(a, a+mid) != get(b, b+mid)) high = mid;
            else low = mid;
        }
        return low;
    }

    // get lcp of S[a:] and T[b:]
    inline int getLCP(const RollingHash &T, int a, int b) const {
        int len = min((int)hash1.size()-a, (int)hash1.size()-b);
        int low = 0, high = len;
        while (high - low > 1) {
            int mid = (low + high) >> 1;
            if (get(a, a+mid) != T.get(b, b+mid)) high = mid;
            else low = mid;
        }
        return low;
    }
};

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

int main() {
    int N; 
    cin >> N;
    vector<set<long long>> vhss(210000);
    for (int i = 0; i < N; ++i) {
        string T; cin >> T;
        vhss[(int)T.size()].insert(RollingHash(T).get());
    }
    vector<pair<int,set<long long>>> hss;
    for (int i = 0; i < 210000; ++i) {
        if (vhss[i].empty()) continue;
        hss.emplace_back(i, vhss[i]);
    }
    string S; 
    cin >> S;
    RollingHash rh(S);
    vector<mint> dp(S.size()+1, 0);
    dp[0] = 1;
    for (int i = 0; i < S.size(); ++i) {
        if (dp[i] == 0) continue;
        for (const auto &it : hss) {
            int len = it.first;
            if (i+len > S.size()) break;
            if (it.second.count(rh.get(i, i+len))) dp[i+len] += dp[i];
        }
    }
    cout << dp[S.size()] << endl;
}