けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

diverta 2019 E - XOR Partitioning (橙色, 800 点)

時間かかりすぎた。シンプルで面白い。

問題へのリンク

問題概要

 N 要素からなる 0 以上の整数列  a_1, a_2, \dots, a_N が与えられる。

これをいくつかの連続した部分列に分割する  2^{N-1} 通りの方法のうち、各連続区間の XOR 和が互いに等しくなるものが何通りあるか、1000000007 で割ったあまりを求めよ。

制約

  •  1 \le N \le 5 × 10^{5}
  •  0 \le a_i \le 2^{20}

考えたこと

「連続する区間」に関する問題なので、累積和をとるのはとても自然。累積和をとると

  •  S_0, S_1, \dots, S_N から  S_0, S_N を含むように何個か選んであげて
  • それが  0, d, 0, d, ... のように  0 d ( d は任意の整数) が交互に並ぶようなもの

を数え上げる問題になる。ここで  S_N \neq 0 の場合は  d = S_N の場合のみ考えればよく、 S_N = 0 の場合は  d = 0, 1, \dots, 2^{20} について考える必要が出てくる。

d = 0 のとき

単純に  S の中の  d の個数を  x 個として  2^{x-2}

それ以外のとき

 S_i = d となる i を  p_0, p_1, \dots, p_{K-1} とする。ここで  p_{0} p_{i} との間にある  0 の個数を  q_{i} とする。

このとき、

  •  {\rm dp} \lbrack i \rbrack :=  S_i を最後に選んだ場合の、そこまでの 0-d 列の個数

とすると dp[ i ] は

  • スタートの  S_0 = 0 からいきなり飛ぶ場合 1 通り
  •  j = 1, 2, ..., i-1 から飛ぶ方法は、 {\rm dp} \lbrack j \rbrack × (q \lbrack i-1 \rbrack - q \lbrack j-1 \rbrack) 通り

となる。よって

 {\rm dp} \lbrack i \rbrack = 1 + \sum_{j = 1}^{i-1} {\rm dp} \lbrack j \rbrack × (q \lbrack i-1 \rbrack - q \lbrack j-1 \rbrack)
 = 1 + q \lbrack i-1 \rbrack × \sum_{j = 1}^{i-1} {\rm dp} \lbrack j \rbrack  - \sum_{j = 1}^{i-1} {\rm dp} \lbrack j \rbrack q \lbrack j-1 \rbrack

となる。このうちの  \sum_{j = 1}^{i-1} {\rm dp} \lbrack j \rbrack \sum_{j = 1}^{i-1} {\rm dp} \lbrack j \rbrack q \lbrack j-1 \rbrack については累積和を用いることで高速化できる。

まとめ

結局、各  d に対して、 S 中の  d の個数  K に対して  O(K) の時間で求めることができたので、全部で  O(N) で求めることができる。

#include <iostream>
#include <vector>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
const int MOD = 1000000007;
using mint = Fp<MOD>;


int N;
vector<int> A, S;          // S: A の累積 XOR
vector<int> zs;            // zs[i] := S[0:i) の中の 0 の個数
vector<vector<int> > inds; // inds[v] := S の中の v のある index たち

mint subsolve(int d, bool iszero = false) { 
    auto p = inds[d];
    int K = (int)p.size();
    vector<int> q(K+1, 0);
    for (int i = 0; i < K; ++i) q[i] = zs[p[i]] - zs[p[0]];

    vector<mint> dp(K+1, 0), sdp(K+2, 0), sdp2(K+2, 0);
    dp[0] = 1;
    for (int i = 1; i <= K; ++i) {
        dp[i] = sdp[i] * q[i-1] - sdp2[i] + 1;
        sdp[i+1] = sdp[i] + dp[i];
        sdp2[i+1] = sdp2[i] + dp[i] * q[i-1];
    }
    mint res = 0;
    if (!iszero) for (int i = 1; i <= K; ++i) res += dp[i];
    else res = dp.back();   
    return res;
}

mint solve() {
    S.assign(N+1, 0);
    for (int i = 0; i < N; ++i) S[i+1] = S[i] ^ A[i];
    zs.assign(N+2, 0);
    for (int i = 0; i < N+1; ++i) zs[i+1] = zs[i] + (S[i] == 0);
    inds.assign((1<<20)+1, vector<int>());
    for (int i = 0; i < N+1; ++i) inds[S[i]].push_back(i);

    mint res = 0;
    if (S.back() == 0) {
        int zeronum = (int)inds[0].size() - 2;
        res += modpow(mint(2), zeronum); // 全て 0 のケース
        for (int d = 1; d < (1<<20); ++d) {
            if (inds[d].empty()) continue;
            res += subsolve(d);
        }
    }
    else res = subsolve(S.back(), true);
    return res;
}

int main() {
    cin >> N; A.resize(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    cout << solve() << endl;
}