けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces CodeCraft-20 (Div. 2) F. Battalion Strength (R2800)

実装がエグエグのエグだけど、実はなんと、遅延評価セグ木すら必要なくて、普通のセグ木だけあれば解けてしまう!

問題へのリンク

問題概要

 N 個の整数  a_{1}, a_{2}, \dots, a_{N} に対して定まる量  f を次のように定義する:

  •  a の部分集合を選ぶ  2^{N} 通りの方法から一様ランダムに選んで、さらにそれを小さい順にソートして  b_{1}, \dots, b_{k} とする
  •  X = b_{1} b_{2} + \dots + b_{k-1} b_{k} で定義される確率変数  X を考えたとき
  •  f = E \lbrack X \rbrack と定義する

まず与えられた数列についての  f の値を答えたあと、以下の  Q 個のクエリに答えよ

  • 各クエリは、数列のうちの一箇所の値を置き換えるものである
  • 置き換えた数列についての  f の値を出力せよ。

制約

  •  N, Q \le 3 \times 10^{5}

考えたこと

普通に考えると、数列の値書き換えは「値の削除」と「値の挿入」をこなす必要があるように思えるため、普通のセグ木じゃ対応できない問題に思える。

そこでテクニックとして、元の数列の index を、クエリの index を全部混ぜてサイズ  N+Q のセグ木を構築してしまうことにした。初期状態ではクエリの index に対応する部分は 0 といった値にしておいて、クエリが進行するにつれて、数列の置き換えられた部分は 0 に更新されて該当クエリ部分に値が入っていくイメージ。

そして、かなり上手にセグ木の状態を定義すれば、遅延評価すらしないセグ木で解くことができる。

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

template<class Monoid> struct SegTree {
    using Func = function<Monoid(Monoid, Monoid)>;
    const Func F;
    const Monoid UNITY;
    int SIZE_R;
    vector<Monoid> dat;
    
    SegTree(int n, const Func f, const Monoid &unity): F(f), UNITY(unity) { init(n); }
    void init(int n) {
        SIZE_R = 1;
        while (SIZE_R < n) SIZE_R *= 2;
        dat.assign(SIZE_R * 2, UNITY);
    }
    
    /* set, a is 0-indexed */
    void set(int a, const Monoid &v) { dat[a + SIZE_R] = v; }
    void build() {
        for (int k = SIZE_R - 1; k > 0; --k)
            dat[k] = F(dat[k*2], dat[k*2+1]);
    }
    
    /* update a, a is 0-indexed */
    void update(int a, const Monoid &v) {
        int k = a + SIZE_R;
        dat[k] = v;
        while (k >>= 1) dat[k] = F(dat[k*2], dat[k*2+1]);
    }
    
    /* get [a, b), a and b are 0-indexed */
    Monoid get(int a, int b) {
        Monoid vleft = UNITY, vright = UNITY;
        for (int left = a + SIZE_R, right = b + SIZE_R; left < right; left >>= 1, right >>= 1) {
            if (left & 1) vleft = F(vleft, dat[left++]);
            if (right & 1) vright = F(dat[--right], vright);
        }                                                                                                              
        return F(vleft, vright);
    }
    inline Monoid operator [] (int a) { return dat[a + SIZE_R]; }
    
    /* debug */
    void print() {
        for (int i = 0; i < SIZE_R; ++i) {
            cout << (*this)[i];
            if (i != SIZE_R-1) cout << ",";
        }
        cout << endl;
    }
};

using pint = pair<int, int>; // (type, query_id)
using tup = pair<long long, pint>;

struct Node {
    mint val, left, right;
    long long num;
    Node(long long v = 0, long long n = 0) :
        val(0), left(v), right(v), num(n) {}
    friend ostream& operator << (ostream &os, const Node &n) {
        return os << "(" << n.val << ", " << n.left << ", " << n.right << ", " << n.num << ")";
    }
};
 
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    
    int N; cin >> N;
    vector<long long> P(N);
    for (int i = 0; i < N; ++i) cin >> P[i];
    int Q; cin >> Q;
    vector<long long> I(Q), X(Q);
    for (int i = 0; i < Q; ++i) cin >> I[i] >> X[i], --I[i];
    vector<tup> all;
    for (int i = 0; i < N; ++i) all.push_back(tup(P[i], pint(0, -1)));
    for (int i = 0; i < Q; ++i) all.push_back(tup(X[i], pint(1, i)));
    sort(all.begin(), all.end());
    
    vector<mint> tpow(N+Q+5, 0);
    tpow[0] = 1;
    for (int i = 0; i + 1 < tpow.size(); ++i) tpow[i+1] = tpow[i] * 2;
    auto func = [&](const Node &a, const Node &b) {
        Node res;
        res.val = a.val * tpow[b.num] + b.val * tpow[a.num] + a.left * b.right;
        res.left = a.left + b.left * tpow[a.num];
        res.right = a.right * tpow[b.num] + b.right;
        res.num = a.num + b.num;
        return res;
    };
    Node unit(0, 0);
    SegTree<Node> seg(N + Q + 1, func, unit);
    vector<int> pid(Q);
    for (int i = 0; i < all.size(); ++i) {
        if (all[i].second.first == 0) {
            seg.set(i, Node(all[i].first, 1));
        }
        else {
            pid[all[i].second.second] = i;
            seg.set(i, unit);
        }
    }
    seg.build();
    cout << seg.get(0, N+Q).val / tpow[N] << endl;
    for (int i = 0; i < Q; ++i) {
        long long before = P[I[i]], after = X[i];
        tup t(before, pint(0, -1));
        int allid = lower_bound(all.begin(), all.end(), t) - all.begin();
        seg.update(allid, unit);
        P[I[i]] = after;
        all[allid] = tup(before, pint(-1, -1));
        seg.update(pid[i], Node(after, 1));
        cout << seg.get(0, N+Q).val / tpow[N] << endl;
    }
}