けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder Library Practice Contest J - Segment Tree

セグメント木の練習問題です。

クエリタイプ 1, 2 のみなら、ただの RMQ ですね。クエリタイプ 3 は、セグメント木上の二分探索を実行する関数 max_right() が使えます。

問題概要

長さ  N の数列  A_{1}, A_{2}, \dots, A_{N} がある。この数列に対して、以下の 2 種類のクエリに答えよ ( Q 個のクエリが与えられる)。

  • クエリ 1: (x, v) が与えられるので、 A_{x} の値を  v に更新せよ
  • クエリ 2: (l, r) が与えられるので、 A_{l}, A_{l+1}, \dots, A_{r-1}, A_{r} の最大値を求めよ
  • クエリ 3: (x, v) が与えられるので、 A_{x} \le j \le N, v \le A_{j} を満たす最小の  j を求めよ (存在しない場合は  N+1 を出力せよ)

制約

  •  1 \le N, Q \le 2 \times 10^{5}

考えたこと

クエリタイプ 1, 2 のみならば、次の鉄則本にある RMQ とほとんど同じですね。ただし、クエリ 2 で区間  (l, r) といったときに、区間が  A_{r} を含むかどうかは問題によって異なるので注意しましょう。今回の問題では含みます。

drken1215.hatenablog.com

それでは、クエリタイプ 3 を考えます。愚直な解法としては、次のような二分探索法で解けます。


配列  A の、区間  \lbrack x, j) の最大値が  v 未満であるような最大の  j を二分探索法で求める。

 j+1 が答えである。


この解法の計算量は  O((\log N)^{2}) となります (繰り返し回数: O(\log N)、判定処理: O(\log N))。これでも十分通るのですが、更なる高速化として「セグメント木上で二分探索する」という解法があります。それによって  O(\log N) の計算量となります。セグメント木の二分探索については、次の記事に記載があります。

algo-logic.info

ACL の segtree では、なんと、セグメント木上の二分探索を実行する関数 max_right() も提供されています。二分探索の判定関数を関数オブジェクトとして引き渡します。具体的な使い方はドキュメントを参照しましょう。

 

コード

#include <bits/stdc++.h>
#include <atcoder/segtree>
using namespace std;
using namespace atcoder;

// セグメント木のための二項演算関数 op と、単位元を返す関数 e
int op(int a, int b) { return max(a, b); }
int e() { return -1; }

// セグメント木上の二分探索のための判定関数 (v は入力から受け取る)
// 区間 [x, j) の最大値 seg_val が v 未満であるような最大の j を求めたい
int v;
bool f(int seg_val) { return seg_val < v; }

int main() {
    int N, Q;
    cin >> N >> Q;
    vector<int> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // セグメント木の設定
    segtree<int, op, e> seg(A);

    // 各クエリ処理
    while (Q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int x, v;
            cin >> x >> v;
            --x;
            seg.set(x, v);
        } else if (t == 2) {
            int l, r;
            cin >> l >> r;
            --l;
            cout << seg.prod(l, r) << endl;
        } else if (t == 3) {
            int x;
            cin >> x >> v;
            --x;
            cout << seg.max_right<f>(x) + 1 << endl;
        }
    }
}

自前ライブラリでも AC

自前ライブラリでも AC しておきます。

#include<bits/stdc++.h>
using namespace std;

// Segment Tree
template<class Monoid> struct SegTree {
    using Func = function<Monoid(Monoid, Monoid)>;

    // core member
    int N;
    Func OP;
    Monoid IDENTITY;
    
    // inner data
    int offset;
    vector<Monoid> dat;

    // constructor
    SegTree() {}
    SegTree(int n, const Func &op, const Monoid &identity) {
        init(n, op, identity);
    }
    SegTree(const vector<Monoid> &v, const Func &op, const Monoid &identity) {
        init((int)v.size(), op, identity);
        build(v);
    }
    void init(int n, const Func &op, const Monoid &identity) {
        N = n;
        OP = op;
        IDENTITY = identity;
        offset = 1;
        while (offset < N) offset *= 2;
        dat.assign(offset * 2, IDENTITY);
    }
    void init(const vector<Monoid> &v, const Func &op, const Monoid &identity) {
        init((int)v.size(), op, identity);
        build(v);
    }
    void build(const vector<Monoid> &v) {
        assert(N == (int)v.size());
        for (int i = 0; i < N; ++i) dat[i + offset] = v[i];
        for (int k = offset - 1; k > 0; --k) dat[k] = OP(dat[k*2], dat[k*2+1]);
    }
    int size() const {
        return N;
    }
    Monoid operator [] (int a) const { return dat[a + offset]; }
    
    // update A[a], a is 0-indexed, O(log N)
    void set(int a, const Monoid &v) {
        int k = a + offset;
        dat[k] = v;
        while (k >>= 1) dat[k] = OP(dat[k*2], dat[k*2+1]);
    }
    
    // get [a, b), a and b are 0-indexed, O(log N)
    Monoid prod(int a, int b) {
        Monoid vleft = IDENTITY, vright = IDENTITY;
        for (int left = a + offset, right = b + offset; left < right;
        left >>= 1, right >>= 1) {
            if (left & 1) vleft = OP(vleft, dat[left++]);
            if (right & 1) vright = OP(dat[--right], vright);
        }
        return OP(vleft, vright);
    }
    Monoid all_prod() { return dat[1]; }
    
    // get max r that f(get(l, r)) = True (0-indexed), O(log N)
    // f(IDENTITY) need to be True
    int max_right(const function<bool(Monoid)> f, int l = 0) {
        if (l == N) return N;
        l += offset;
        Monoid sum = IDENTITY;
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(OP(sum, dat[l]))) {
                while (l < offset) {
                    l = l * 2;
                    if (f(OP(sum, dat[l]))) {
                        sum = OP(sum, dat[l]);
                        ++l;
                    }
                }
                return l - offset;
            }
            sum = OP(sum, dat[l]);
            ++l;
        } while ((l & -l) != l);  // stop if l = 2^e
        return N;
    }

    // get min l that f(get(l, r)) = True (0-indexed), O(log N)
    // f(IDENTITY) need to be True
    int min_left(const function<bool(Monoid)> f, int r = -1) {
        if (r == 0) return 0;
        if (r == -1) r = N;
        r += offset;
        Monoid sum = IDENTITY;
        do {
            --r;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(OP(dat[r], sum))) {
                while (r < offset) {
                    r = r * 2 + 1;
                    if (f(OP(dat[r], sum))) {
                        sum = OP(dat[r], sum);
                        --r;
                    }
                }
                return r + 1 - offset;
            }
            sum = OP(dat[r], sum);
        } while ((r & -r) != r);
        return 0;
    }
    
    // debug
    friend ostream& operator << (ostream &s, const SegTree &seg) {
        for (int i = 0; i < seg.size(); ++i) {
            s << seg[i];
            if (i != seg.size()-1) s << " ";
        }
        return s;
    }
};


// ACL practice J - Segment Tree
void ACL_practice_J() {
    int N, Q;
    cin >> N >> Q;
    vector<int> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    
    // セグ木の準備 (型: int, 演算方法: op, 単位元: -INF)
    const int INF = 1<<30;
    SegTree<int> seg(A, [&](int a, int b){ return max(a, b); }, -INF);
    
    // クエリ処理
    while (Q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int X, V;
            cin >> X >> V;
            --X;
            
            // A[x] を V に update する処理
            seg.set(X, V);
        } else if (t == 2) {
            int L, R;
            cin >> L >> R;
            --L;
            
            // A の区間 [L, R) の最大値を求める処理
            cout << seg.prod(L, R) << endl;
        } else {
            int L, V;
            cin >> L >> V;
            --L;
            
            // x = seg.prod(L, r) として、f(x) = True となる最大の r を求める
            int res = seg.max_right([&](int x) -> bool { return V > x; }, L);
            
            // 求めたいのは V <= prod(L, x) となる最小の x
            cout << res + 1 << endl;
        }
    }
}

int main() {
    ACL_practice_J();
}