けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

TopCoder SRM 402 DIV1 Hard - IncreasingSequence (本番 2 人)

詰め切るのが大変だった!

問題概要

'0'〜'9' からなる長さ $N$ の文字列 $S$ が与えられる。

この文字列をいくつかの連続する部分文字列に分ける。次の条件を満たす必要がある。

  • 各部分文字列を数値とみなしたとき、strictly に単調増加である
  • leading zero は許容する

このような分け方のうち、末尾の数値が最小となるものを求めよ。複数通りある場合は、辞書順最大のものを求めよ。そして、求められた分け方について、区切られた数値の積を 1000000003 で割った余りを答えよ。

制約

  • $1 \le N \le 2500$

考えたこと

ぱっと見は DP で一瞬に見えた。一旦、計算量を無視して考えることにする。


dp[i] ← 先頭から $i$ 文字を条件を満たすように区切る方法のうち、末尾を表す数値の最小値


そうして、いつもの DP をすれば良さそうに見える。

for (int i = 1; i <= N; ++i) {
    for (int j = 0; j < i; ++j) {
        if (dp[j] < S[j:i] の表す数) {
            chmin(dp[i], S[j:i] の表す数);
        }
    }
}

という感じだ。また、末尾が決まってからは今度は逆方向に同様の DP を回すことで、辞書順最大のものを求めることができる。計算量を無視すれば解けた。

しかし、dp[i] は最長 $N$ 桁の数値であり、比較 1 回に $O(N)$ かかるので、計算量は $O(N^3)$ となる。このままでは間に合わない。

S[i:j] の序列を求める

僕のとった方法は、文字列の各区間 $[l, r)$ を小さい順にソートすることだった。

具体的には、各区間 $[l, r)$ に対して、区間の表す数値の大小関係に対応するように、整数値 $f(l, r)$ を定めてあげることとした。僕は次のようにやった。


  1. leading zero は取り除く (適切な前処理によって $O(1)$ でできる)
  2. 桁数を $d$ とする
  3. Suffix Array における、区間 $[l, r)$ (から leading zero を取り除いた部分) を表す文字列の lower_bound を求めて、$k$ とする
  4. $f(l, r) = 10000d + k$ とする

3 について、最初は $k$ を、Suffix Array のランクを rank として、rank[l] としていた。しかしこれだと、

 S = "242456"

などのケースで、区間 [0, 2) の "24" と、区間 [2, 4) の "24" に異なる値がついてしまう。よって、lower_bound を求めることにした。

先ほどの DP で、数値の代わりにこの $f$ の値を比較に用いることで、計算量は $O(N^2 \log N)$ へと改善できた ($f(l, r)$ の lower_bound は、LCP 配列上の Sparse Table を用いた二分探索により $O(\log N)$ で求まる)。

leading zero についての注意

0 がたくさんある場合に注意。たとえば

 S = "10000000000010"

の場合、正解は "1", "0000000000010" と区切ることだが、末尾を "10" で区切ってしまうと詰むことに注意する。

コード

テストコードを含む

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;
using pll = pair<long long, long long>;

// chmax(a, b): if b is strictly greater than a, store b into a.
// Returns true exactly when a was overwritten.
template<class T> inline bool chmax(T& a, T b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

// chmin(a, b): if b is strictly smaller than a, store b into a.
// Returns true exactly when a was overwritten.
template<class T> inline bool chmin(T& a, T b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}

// REP(i, n)     : loop i = 0, 1, ..., n-1 with a long long counter.
// REP2(i, a, b) : loop i = a, a+1, ..., b-1 with a long long counter.
// COUT(x)       : debug print "x = <value> (L<line>)" to stdout.
#define REP(i, n) for (long long i = 0; i < static_cast<long long>(n); ++i)
#define REP2(i, a, b) for (long long i = (a); i < static_cast<long long>(b); ++i)
#define COUT(x) cout << #x " = " << (x) << " (L" << __LINE__ << ")" << endl
// Debug printers for common standard containers.
// Containers are taken by const reference (and elements are iterated by
// const reference) so that printing never copies the data being printed.
template<class T1, class T2> ostream& operator << (ostream &s, const pair<T1,T2> &P)
{ return s << '<' << P.first << ", " << P.second << '>'; }
template<class T> ostream& operator << (ostream &s, const vector<T> &P)
{ for (size_t i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }
template<class T> ostream& operator << (ostream &s, const deque<T> &P)
{ for (size_t i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }
// 2D vector: one row per line, each preceded by a newline, plus a trailing newline.
template<class T> ostream& operator << (ostream &s, const vector<vector<T> > &P)
{ for (size_t i = 0; i < P.size(); ++i) { s << endl << P[i]; } return s << endl; }
template<class T> ostream& operator << (ostream &s, const set<T> &P)
{ for (const auto &it : P) { s << "<" << it << "> "; } return s; }
template<class T> ostream& operator << (ostream &s, const multiset<T> &P)
{ for (const auto &it : P) { s << "<" << it << "> "; } return s; }
template<class T1, class T2> ostream& operator << (ostream &s, const map<T1,T2> &P)
{ for (const auto &it : P) { s << "<" << it.first << "->" << it.second << "> "; } return s; }


// Sparse Table: static range-minimum queries.
// init is O(n log n); get(a, b) returns min(vec[a], ..., vec[b-1]) in O(1)
// for 0 <= a < b <= n.
template<class MeetSemiLattice> struct SparseTable {
    vector<vector<MeetSemiLattice> > dat;  // dat[k][i] = min of vec[i, min(i + 2^k, n))
    vector<int> height;                    // height[len] = floor(log2(len)) for len >= 1

    SparseTable() { }
    SparseTable(const vector<MeetSemiLattice> &vec) { init(vec); }
    void init(const vector<MeetSemiLattice> &vec) {
        int n = (int)vec.size();
        height.assign(n+1, 0);
        for (int i = 2; i <= n; ++i) height[i] = height[i>>1] + 1;
        // A query of length len reads row height[len]. The longest query
        // (len = n) therefore needs rows 0..height[n]. Sizing by "smallest h
        // with 2^h >= n" misses that last row whenever n is a power of two >= 2,
        // which made get(0, n) read out of bounds.
        int h = height[n] + 1;
        dat.assign(h, vector<MeetSemiLattice>(n));
        for (int i = 0; i < n; ++i) dat[0][i] = vec[i];
        for (int i = 1; i < h; ++i)
            for (int j = 0; j < n; ++j)
                dat[i][j] = min(dat[i-1][j], dat[i-1][min(j+(1<<(i-1)), n-1)]);
    }

    // min over [a, b); requires a < b.
    MeetSemiLattice get(int a, int b) {
        int k = height[b-a];
        // two overlapping power-of-two windows cover [a, b)
        return min(dat[k][a], dat[k][b-(1<<k)]);
    }
};

// Suffix array built by SA-IS in O(N), together with the LCP array (Kasai)
// and a sparse table over it for O(1) LCP queries between suffixes.
//
// A sentinel element 0 (smaller than every real element) is appended
// internally, so sa and rank have |str| + 1 entries and sa[0] == |str|.
template<class Str> struct SuffixArray {
    // data
    Str str;
    vector<int> sa;    // sa[i] : the starting index of the i-th smallest suffix (i = 0, 1, ..., n)
    vector<int> rank;  // rank[sa[i]] = i
    vector<int> lcp;   // lcp[i]: the lcp of sa[i] and sa[i+1] (i = 0, 1, ..., n-1)
    SparseTable<int> st;  // range-minimum over lcp, used for calculating lcp(i, j)

    // getter
    int& operator [] (int i) { return sa[i]; }
    const int& operator [] (int i) const { return sa[i]; }
    vector<int> get_sa() { return sa; }
    vector<int> get_rank() { return rank; }
    vector<int> get_lcp() { return lcp; }

    // constructor
    // no_limit_elements == false: elements must lie in [0, 255] (e.g. a std::string).
    // no_limit_elements == true : arbitrary int-convertible elements; they are
    //                             coordinate-compressed first.
    SuffixArray() {}
    SuffixArray(const Str& str_, bool no_limit_elements = false) : str(str_) {
        build_sa(no_limit_elements);
    }
    void init(const Str& str_, bool no_limit_elements = false) {
        str = str_;
        build_sa(no_limit_elements);
    }
    void build_sa(bool no_limit_elements = false) {
        vector<int> s;
        int num_of_chars = 256;
        if (!no_limit_elements) {
            // Shift by 1 so that 0 stays reserved for the sentinel. Going through
            // unsigned char keeps bytes >= 0x80 of a signed-char string inside
            // [1, 256] and ordered the way std::string::compare orders them.
            for (int i = 0; i < (int)str.size(); ++i) {
                s.push_back((int)(unsigned char)str[i] + 1);
            }
        } else {
            // Order-preserving coordinate compression. Assigning ids in order of
            // first appearance breaks the relative order of the elements and
            // yields a suffix order that is not lexicographic.
            vector<int> vals;
            vals.reserve(str.size());
            for (int i = 0; i < (int)str.size(); ++i) vals.push_back(str[i]);
            std::sort(vals.begin(), vals.end());
            vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
            for (int i = 0; i < (int)str.size(); ++i) {
                int id = (int)(std::lower_bound(vals.begin(), vals.end(), (int)str[i]) - vals.begin());
                s.push_back(id + 1);
            }
            num_of_chars = (int)vals.size();
        }
        s.push_back(0);
        sa = sa_is(s, num_of_chars);
        build_lcp(s);
        build_sparse_table();
    }

    // SA-IS
    // num_of_chars: largest value that may appear in s (values lie in [0, num_of_chars])
    vector<int> sa_is(vector<int> &s, int num_of_chars) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(num_of_chars + 1, 0), sum_s(num_of_chars + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= num_of_chars; ++i) {
            sum_s[i] += sum_l[i];
            if (i < num_of_chars) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(num_of_chars + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // find min id that str.substr(sa[id]) >= T
    int lower_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], string::npos, T) < 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // find min id that str.substr(sa[id], T.size()) > T
    int upper_bound(const Str& T) {
        int left = -1, right = sa.size();
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (str.compare(sa[mid], T.size(), T) <= 0)
                left = mid;
            else
                right = mid;
        }
        return right;
    }

    // Returns the smallest SA index id such that the suffix starting at sa[id]
    // has str[l, r) as a prefix (i.e. lcp(sa[id], l) >= r - l), found by binary
    // search on the LCP sparse table in O(log N).
    // Intervals spelling the same string get the same id, and for intervals of
    // equal length the ids are ordered like the strings themselves.
    int lower_bound(int l, int r) {
        int left = -1, right = rank[l];
        while (right - left > 1) {
            int mid = (left + right) / 2;
            if (st.get(mid, rank[l]) < r - l) left = mid;
            else right = mid;
        }
        return right;
    }
    
    // search
    bool is_contain(const Str& T) {
        int lb = lower_bound(T);
        if (lb >= (int)sa.size()) return false;
        return str.compare(sa[lb], T.size(), T) == 0;
    }

    // find lcp (Kasai et al.); s is the sentinel-terminated sequence used by sa_is
    void build_lcp(const vector<int> &s) {
        int N = (int)s.size();
        rank.assign(N, 0), lcp.assign(N - 1, 0);
        for (int i = 0; i < N; ++i) rank[sa[i]] = i;
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rank[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rank[i] - 1] = h;
        }
    }
    
    // build sparse table for calculating lcp
    void build_sparse_table() {
        st.init(lcp);
    }

    // calc lcp of str.substr(a) and str.substr(b)
    int get_lcp(int a, int b) {
        // identical suffixes: the range query below would be empty (invalid)
        if (a == b) return (int)str.size() - a;
        return st.get(min(rank[a], rank[b]), max(rank[a], rank[b]));
    }

    // debug
    void dump() {
        for (int i = 0; i < (int)sa.size(); ++i) {
            cout << i << ": " << sa[i] << ", " << str.substr(sa[i]) << endl;
        }
    }
};

// Fp<MOD>: integer residue modulo the compile-time constant MOD, kept in
// canonical form [0, MOD). Division and inv() use the extended Euclidean
// algorithm, so they are only meaningful for divisors coprime to MOD.
template<int MOD> struct Fp {
    long long val;  // residue in [0, MOD)

    // --- construction / access ---
    constexpr Fp() noexcept : val(0) { }
    constexpr Fp(long long v) noexcept : val((v % MOD + MOD) % MOD) { }
    constexpr long long get() const noexcept { return val; }
    constexpr int get_mod() const noexcept { return MOD; }

    // --- arithmetic ---
    constexpr Fp operator - () const noexcept {
        return Fp(val == 0 ? 0 : MOD - val);
    }
    constexpr Fp operator + (const Fp &r) const noexcept { Fp t(*this); t += r; return t; }
    constexpr Fp operator - (const Fp &r) const noexcept { Fp t(*this); t -= r; return t; }
    constexpr Fp operator * (const Fp &r) const noexcept { Fp t(*this); t *= r; return t; }
    constexpr Fp operator / (const Fp &r) const noexcept { Fp t(*this); t /= r; return t; }
    constexpr Fp& operator += (const Fp &r) noexcept {
        const long long sum = val + r.val;
        val = (sum >= MOD) ? sum - MOD : sum;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) noexcept {
        const long long diff = val - r.val;
        val = (diff < 0) ? diff + MOD : diff;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) noexcept {
        val = (val * r.val) % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) noexcept {
        // Extended Euclid on (r.val, MOD). Invariant: x0 * r.val == g (mod MOD),
        // so at the end x0 * r.val == gcd(r.val, MOD) (mod MOD).
        long long g = r.val, m = MOD, x0 = 1, x1 = 0;
        while (m != 0) {
            const long long q = g / m;
            const long long next_m = g - q * m;
            g = m;
            m = next_m;
            const long long next_x = x0 - q * x1;
            x0 = x1;
            x1 = next_x;
        }
        val = val * x0 % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    // binary exponentiation; n <= 0 yields 1
    constexpr Fp pow(long long n) const noexcept {
        Fp acc(1);
        for (Fp base(*this); n > 0; n >>= 1) {
            if (n & 1) acc *= base;
            base *= base;
        }
        return acc;
    }
    constexpr Fp inv() const noexcept {
        return Fp(1) / *this;
    }

    // --- comparison / I/O / free helpers ---
    constexpr bool operator == (const Fp &r) const noexcept {
        return val == r.val;
    }
    constexpr bool operator != (const Fp &r) const noexcept {
        return !(*this == r);
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) noexcept {
        is >> x.val;
        x.val = Fp<MOD>(x.val).val;  // reduce into [0, MOD)
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &r, long long n) noexcept {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD> &r) noexcept {
        return r.inv();
    }
};

// Input string (concatenation of all digit strings) and its suffix array.
// Both are reset by IncreasingSequence::getProduct and read by order().
string S;
SuffixArray<string> sa;

// nex[i]: the first index j >= i with S[j] != '0' (N if there is none)
vector<int> nex;

// S[l:r] が何番目か
long long order(int l, int r, bool debug = false) {
    long long tl = nex[l];
    if (tl >= r) return 0;
    long long len = r - tl;
    long long rank = sa.rank[tl];
    long long res = len * 10000 + sa.lower_bound(tl, r);
    return res;
}

class IncreasingSequence {
public:
    // Concatenates `digits` into S and splits S into a strictly increasing
    // sequence of numbers (leading zeros allowed). Among splits whose last
    // element is as small as possible, it picks the lexicographically largest
    // sequence. Returns the product of the parts modulo 1000000003.
    //
    // Numbers are compared through order(l, r), which agrees with numeric
    // comparison, so both DPs below run in O(N^2 log N).
    int getProduct(vector <string> digits) {
        S = "";
        for (const auto &s : digits) S += s;
        int N = S.size();

        // nex[i]: first index >= i holding a non-zero digit (N if none).
        nex.assign(N, N);
        for (int i = N-1; i >= 0; --i) {
            if (S[i] != '0') nex[i] = i;
            else if (i < N-1) nex[i] = nex[i+1];
        }
        sa.init(S);

        // forward: dp[i] = smallest key of the last part over all valid splits
        // of S[0, i) (INF if there is none).
        // dp[0] = -1 (below every key) lets the first part be 0. With dp[0] = 0
        // an all-zero S (e.g. "0") had no valid split, pre2[0] stayed 0, and the
        // reconstruction loop below never terminated. For any S containing a
        // non-zero digit the answer is unchanged: a leading part equal to 0 can
        // always be merged into the next part without changing any value.
        const long long INF = 1LL<<60;
        vector<long long> dp(N+1, INF);
        dp[0] = -1;
        for (int i = 1; i <= N; ++i) {
            for (int j = 0; j < i; ++j) {
                const long long key = order(j, i);  // evaluate once; each call is O(log N)
                if (key > dp[j]) chmin(dp[i], key);
            }
        }
        const long long last = dp[N];

        // backward: dp2[i] = largest key of the first part over strictly
        // increasing splits of S[i, N) whose last part has key `last`
        // (-1 if none); pre2[i] = end index of that first part.
        vector<long long> dp2(N+1, -1);
        vector<int> pre2(N+1);
        for (int i = 0; i < N; ++i) if (order(i, N) == last) {
            dp2[i] = last;
            pre2[i] = N;
        }
        for (int i = N-1; i >= 0; --i) {
            for (int j = i+1; j < N; ++j) {
                const long long key = order(i, j);
                if (key < dp2[j] && chmax(dp2[i], key)) pre2[i] = j;
            }
        }

        const int MOD = 1000000003;
        using mint = Fp<MOD>;
        // value of S[l, r) modulo MOD
        auto calc = [&](int l, int r) -> mint {
            mint res = 0;
            for (int i = l; i < r; ++i) res = res * 10 + (S[i] - '0');
            return res;
        };

        // Follow pre2 from the front: repeatedly taking the largest feasible
        // first part yields the lexicographically largest sequence.
        mint res = 1;
        int i = 0;
        while (i != N) {
            int i2 = pre2[i];
            assert(i2 > i && i2 <= N);  // every step must consume at least one digit
            res *= calc(i, i2);
            i = i2;
        }
        return res.get();
    }
};



// BEGIN CUT HERE
namespace moj_harness {
int run_test_case(int);

// Runs a single case (casenum != -1) or every case 0..100, printing a summary to stderr.
void run_test(int casenum = -1, bool quiet = false) {
    if (casenum != -1) {
        if (run_test_case(casenum) == -1 && !quiet) {
            cerr << "Illegal input! Test case " << casenum << " does not exist." << endl;
        }
        return;
    }
    
    int correct = 0, total = 0;
    for (int i=0; i <= 100; ++i) {
        int x = run_test_case(i);
        if (x == -1) {
            if (i >= 100) break;
            continue;
        }
        correct += x;
        ++total;
    }
    
    if (total == 0) {
        cerr << "No test cases run." << endl;
    } else if (correct < total) {
        cerr << "Some cases FAILED (passed " << correct << " of " << total << ")." << endl;
    } else {
        cerr << "All " << total << " tests passed!" << endl;
    }
}

// Prints the verdict (and elapsed time if noticeable) for one case.
// Returns 1 if expected == received, 0 otherwise.
int verify_case(int casenum, const int &expected, const int &received, clock_t elapsed) {
    cerr << "Example " << casenum << "... ";
    
    string verdict;
    vector<string> info;
    char buf[100];
    
    if (elapsed > CLOCKS_PER_SEC / 200) {
        // snprintf bounds the write to the buffer size (sprintf does not)
        snprintf(buf, sizeof buf, "time %.2fs", elapsed * (1.0/CLOCKS_PER_SEC));
        info.push_back(buf);
    }
    
    if (expected == received) {
        verdict = "PASSED";
    } else {
        verdict = "FAILED";
    }
    
    cerr << verdict;
    if (!info.empty()) {
        cerr << " (";
        for (int i=0; i<(int)info.size(); ++i) {
            if (i > 0) cerr << ", ";
            cerr << info[i];
        }
        cerr << ")";
    }
    cerr << endl;
    
    if (verdict == "FAILED") {
        cerr << "    Expected: " << expected << endl;
        cerr << "    Received: " << received << endl;
    }
    
    return verdict == "PASSED";
}

// Times IncreasingSequence::getProduct on `digits` and checks it against `expected`.
int check_case(int casenum, const vector<string> &digits, int expected) {
    clock_t start = clock();
    int received = IncreasingSequence().getProduct(digits);
    return verify_case(casenum, expected, received, clock() - start);
}

// Returns 1 (passed), 0 (failed), or -1 if no case has this number.
int run_test_case(int casenum__) {
    switch (casenum__) {
        case 0: return check_case(casenum__, {"12345"}, 120);
        case 1: return check_case(casenum__, {"543210"}, 45150);
        case 2: return check_case(casenum__, {"20210222"}, 932400);
        case 3: return check_case(casenum__, {"1111111111"}, 1356531);
        case 4: return check_case(casenum__, {"171829294246"}, 385769340);
        case 5: return check_case(casenum__, {"3", "235", "236"}, 264320);

        // custom cases

        // Empty input. An empty vector is used because a zero-length array
        // (`string digits[] = {};`) is ill-formed in ISO C++.
        case 6: return check_case(casenum__, {}, 1);
        case 7: return check_case(casenum__, {"1"}, 1);
        case 8: return check_case(casenum__, {"1000001"}, 1000001);
        case 9: return check_case(casenum__, {"10100010"}, 1000100);
        case 10: return check_case(casenum__, {"1010100100"}, 101101000);
        default:
            return -1;
    }
}
}
 

// Local test driver: with no arguments every harness case is run;
// otherwise each argument is parsed as a case number and run individually.
int main(int argc, char *argv[]) {
    if (argc == 1) {
        moj_harness::run_test();
        return 0;
    }
    for (int k = 1; k < argc; ++k) {
        const int casenum = atoi(argv[k]);
        moj_harness::run_test(casenum);
    }
}
// END CUT HERE