けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 099 F - Eating Symbols Hard (赤色, 1200 点)

楽しかった。こういうのでロリハ使うの楽しい。発想自体は Zero-Sum Ranges (200 点) と似てる。

問題へのリンク

問題概要

高橋君は、いつも頭の中に長さ 2000000001 の数列  A_{-1000000000}, \dots, A_{-1}, A_{0}, A_{1}, \dots, A_{1000000000} と、整数値  P を思い浮かべている。初期状態では、数列の各要素値と、 P の値はすべて 0 である。

ここで長さ  N の文字列  S が与えられる。各文字は "+-><" のいずれかであり

  • '+' のとき、 A_{P} をインクリメントする
  • '-' のとき、 A_{P} をデクリメントする
  • '>' のとき、 P をインクリメントする
  • '<' のとき、 P をデクリメントする

この  N 回の操作を行ってできる結果の数列を  T とする。 S の空でない部分区間であって、その区間の操作のみを初期状態の数列と  P に対して行ってできる結果の数列が、 T に一致するものを数え上げよ ( P の値は一致しなくてよい)

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

 O(N^{2}) 個考えられる区間を高速に扱う方法なんて限られている。しゃくとり法などが代表的だけど、とてもそれが適用できる雰囲気の問題ではない。一つあるのは、

  • 累積和の要領で、区間 [0,  i ) に関する状態をハッシュ化して、
  •  i に対して、区間 [  i,  j ) が条件を満たすような  j がどのようなハッシュ値をもつべきかを求め、そのような  j をカウントする

という考え方。この考え方の最も簡単な場合が、Zero-Sum Ranges といえる。

ハッシュ化

数列の状態は以下のようにして自然にハッシュ化できる。 B を適当な値、 M を素数として、

  •  \sum_{i} A_{i} \times B^{i} (mod. M)

 i が負になる部分もあるが、問題ない。これはロリハそのものでもある。 i 個目の操作を終えた段階でのハッシュ値を  H_{i}、その段階での  P の値を  P_{i} とおく。こうしておくと、各  i に対して、ハッシュ値が

  •  H_{j} = H_{i} + H_{N} \times B^{P_{i}}

となるような区間 [0,  j) (  j \gt i) の個数をカウントすれば OK

ハッシュの衝突確率

ハッシュの衝突確率に関する議論は、公式解説にある。

https://img.atcoder.jp/arc099/editorial.pdf

僕の場合、(MOD, BASE) の組が 1 組のみだと衝突した。2 組にしたら通った。

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <algorithm>
using namespace std;

long long modinv(long long a, long long mod) {
    long long b = mod, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a -= t*b; swap(a, b);
        u -= t*v; swap(u, v);
    }
    u %= mod;
    if (u < 0) u += mod;
    return u;
}

long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    if (n < 0) {
        a = modinv(a, mod);
        n = -n;
        return modpow(a, n, mod);
    }
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

const vector<long long> MOD = {1000000009, 1000000007};
const vector<long long> BASE = {17, 1009};

long long solve(int N, const string &S) {
    vector<vector<long long>> hash(2, vector<long long>(N+1, 0));
    vector<vector<long long>> pval(2, vector<long long>(N+1, 0));
    map<pair<long long, long long>, vector<int>> pos;

    for (int i = 0; i < N; ++i) {
        for (int it = 0; it < 2; ++it) {
            if (S[i] == '>') {
                hash[it][i+1] = hash[it][i];
                pval[it][i+1] = pval[it][i] + 1;
            }
            else if (S[i] == '<') {
                hash[it][i+1] = hash[it][i];
                pval[it][i+1] = pval[it][i] - 1;
            }
            else if (S[i] == '+') {
                pval[it][i+1] = pval[it][i];
                long long add = modpow(BASE[it], pval[it][i], MOD[it]);
                hash[it][i+1] = (hash[it][i] + add) % MOD[it];
            }
            else if (S[i] == '-') {
                pval[it][i+1] = pval[it][i];
                long long add = modpow(BASE[it], pval[it][i], MOD[it]);
                hash[it][i+1] = (hash[it][i] - add + MOD[it]) % MOD[it];
            }
        }
        pos[{hash[0][i+1], hash[1][i+1]}].push_back(i+1);
    }

    long long res = 0;
    for (int i = 0; i <= N; ++i) {
        vector<long long> risou_add(2), risou(2);
        for (int it = 0; it < 2; ++it) {
            risou_add[it] = hash[it][N] * modpow(BASE[it], pval[it][i], MOD[it]) % MOD[it];
            risou[it] = (hash[it][i] + risou_add[it]) % MOD[it];
        }
        auto &v = pos[{risou[0], risou[1]}];
        int it = upper_bound(v.begin(), v.end(), i) - v.begin();
        res += (int)v.size() - it;  
    }
    return res;
}

int main() {
    int N; string S;
    while (cin >> N >> S) cout << solve(N, S) << endl;
}