けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

鉄則本 A58 - RMQ (Range Maximum Queries)

セグメント木の最初の練習問題によさそうな問題

問題概要

長さ  N の数列  A_{1}, A_{2}, \dots, A_{N} がある。最初はすべての要素が 0 となっている。この数列に対して、以下の 2 種類のクエリに答えよ ( Q 個のクエリが与えられる)。

  • クエリ 1: (p, x) が与えられるので、 A_{p} の値を  x に更新せよ
  • クエリ 2: (l, r) が与えられるので、 A_{l}, A_{l+1}, \dots, A_{r-1} の最大値を求めよ

制約

  •  1 \le N, Q \le 10^{5}

解法

クエリ 1 は「値の更新」に関するクエリで、クエリ 2 は「値の取得」に関するクエリとなっています。さらに言えば、

  • クエリ 1 は、数列の 1 要素に関する「値の更新」
  • クエリ 2 は、数列の区間全体に関する「値の取得」

となっています。このようなクエリを「1 点更新区間取得」であると言うことがあります。

このクエリを愚直に実行しようとすると、1 点更新は  O(1) の計算量でできるものの、区間取得は  O(N) の計算量を必要としています。区間の長さの分だけ要素を 1 つ 1 つ見る必要があるためです。

このようなとき、セグメント木が活躍することが多々あります。セグメント木を使うと、区間取得が  O(\log N) でできるようになります (その代わり 1 点更新にも  O(\log N) の計算量を要します)。セグメント木については、たとえば次の記事で詳しく解説されています。

algo-logic.info

tsutaj.hatenablog.com

 

セグメント木の使い方

ここではセグメント木を「なぜか区間に対する値が素早く決まる配列」と思うことにします。セグメント木を構成するためには、次の 3 つのものを定めます。


  • 配列の各要素の
  • 配列の各要素間の演算方法
  • その演算における単位元

たとえば、今回の問題のように「整数列の区間の最大値を求めよ」という問題であれば、

  • 型:int
  • 演算方法: f(a, b) = \max(a, b)
  • 単位元: -\infty

などとすればよいでしょう。なお、 e が単位元であるとは、 f(a, e) = a が常に成り立つようなもののことを言います。たとえば、

  • 演算方法が足し算  f(a, b) = a + b であれば、単位元は  0
  • 演算方法が掛け算  f(a, b) = ab であれば、単位元は  1
  • 演算方法が min を取る操作  f(a, b) = \min(a, b) であれば、単位元は  \infty

となります。その他の、セグ木の詳しい使い方等は次のコード例を参考に。

 

コード 1 (ACL を用いた実装)

ここでは、AtCoder の公式ライブラリである ACL に搭載されている segtree を用いた実装例を示します。なお、ACL の segtree のドキュメントは次のページにあります。

https://atcoder.github.io/ac-library/document_ja/segtree.html

#include<bits/stdc++.h>
#include<atcoder/segtree>
using namespace std;
using namespace atcoder;

// 十分大きい値
const int INF = 1<<30;

// 演算方法
int op(int a, int b) { return max(a, b); }

// 単位元
int e() { return -INF; }

int main() {
    int N, Q;
    cin >> N >> Q;
    
    // セグ木の準備 (型: int, 演算方法: op, 単位元: e)
    segtree<int, op, e> seg(N);
    
    // 初期状態では数列 A の値はすべて 0 であるため、各要素を 0 にする
    for (int i = 0; i < N; ++i) seg.set(i, 0);
    
    // クエリ処理
    while (Q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int pos, x;
            cin >> pos >> x;
            --pos;

            // A[pos] を x に更新する処理
            seg.set(pos, x);
        } else {
            int l, r;
            cin >> l >> r;
            --l, --r;

            // A の区間 [l, r) の最大値を求める処理
            cout << seg.prod(l, r) << endl;
        }
    }
}

コード 2 (自分のセグメント木)

次に、自前で用意したセグメント木でも AC します。

#include<bits/stdc++.h>
using namespace std;


// Segment Tree
template<class Monoid, Monoid(*OP)(Monoid, Monoid), Monoid IDENTITY> struct SegTree {
    // size of segtree-array
    int N;
    
    // inner data
    int offset;
    vector<Monoid> dat;

    // constructor
    SegTree() {}
    SegTree(int n) : N(n) {
        init(n);
    }
    SegTree(const vector<Monoid> &v) : N(v.size()) {
        init(v);
    }
    void init(int n) {
        N = n;
        offset = 1;
        while (offset < N) offset *= 2;
        dat.assign(offset * 2, IDENTITY);
    }
    void init(const vector<Monoid> &v) {
        N = (int)v.size();
        offset = 1;
        while (offset < N) offset *= 2;
        dat.assign(offset * 2, IDENTITY);
        for (int i = 0; i < N; ++i) dat[i + offset] = v[i];
        build();
    }
    void build() {
        for (int k = offset - 1; k > 0; --k) dat[k] = OP(dat[k*2], dat[k*2+1]);
    }
    int size() const {
        return N;
    }
    Monoid operator [] (int a) const { return dat[a + offset]; }
    
    // update A[a], a is 0-indexed, O(log N)
    void set(int a, const Monoid &v) {
        int k = a + offset;
        dat[k] = v;
        while (k >>= 1) dat[k] = OP(dat[k*2], dat[k*2+1]);
    }
    
    // get [a, b), a and b are 0-indexed, O(log N)
    Monoid prod(int a, int b) {
        Monoid vleft = IDENTITY, vright = IDENTITY;
        for (int left = a + offset, right = b + offset; left < right;
        left >>= 1, right >>= 1) {
            if (left & 1) vleft = OP(vleft, dat[left++]);
            if (right & 1) vright = OP(dat[--right], vright);
        }
        return OP(vleft, vright);
    }
    Monoid all_prod() { return dat[1]; }
    
    // get max r that f(get(l, r)) = True (0-indexed), O(log N)
    // f(IDENTITY) need to be True
    int max_right(const function<bool(Monoid)> f, int l = 0) {
        if (l == N) return N;
        l += offset;
        Monoid sum = IDENTITY;
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(OP(sum, dat[l]))) {
                while (l < offset) {
                    l = l * 2;
                    if (f(OP(sum, dat[l]))) {
                        sum = OP(sum, dat[l]);
                        ++l;
                    }
                }
                return l - offset;
            }
            sum = OP(sum, dat[l]);
            ++l;
        } while ((l & -l) != l);  // stop if l = 2^e
        return N;
    }

    // get min l that f(get(l, r)) = True (0-indexed), O(log N)
    // f(IDENTITY) need to be True
    int min_left(const function<bool(Monoid)> f, int r = -1) {
        if (r == 0) return 0;
        if (r == -1) r = N;
        r += offset;
        Monoid sum = IDENTITY;
        do {
            --r;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(OP(dat[r], sum))) {
                while (r < offset) {
                    r = r * 2 + 1;
                    if (f(OP(dat[r], sum))) {
                        sum = OP(dat[r], sum);
                        --r;
                    }
                }
                return r + 1 - offset;
            }
            sum = OP(dat[r], sum);
        } while ((r & -r) != r);
        return 0;
    }
    
    // debug
    friend ostream& operator << (ostream &s, const SegTree &seg) {
        for (int i = 0; i < seg.size(); ++i) {
            s << seg[i];
            if (i != seg.size()-1) s << " ";
        }
        return s;
    }
};

// 十分大きい値
const int INF = 1<<30;

// 演算方法
int op(int a, int b) { return max(a, b); }

int main() {
    int N, Q;
    cin >> N >> Q;
    
    // セグ木の準備 (型: int, 演算方法: op, 単位元: -INF)
    SegTree<int, op, -INF> seg(N);
    
    // 初期状態では数列 A の値はすべて 0 であるため、各要素を 0 にする
    for (int i = 0; i < N; ++i) seg.set(i, 0);
    
    // クエリ処理
    while (Q--) {
        int t;
        cin >> t;
        if (t == 1) {
            int pos, x;
            cin >> pos >> x;
            --pos;
            
            // A[pos] を x に更新する処理
            seg.set(pos, x);
        } else {
            int l, r;
            cin >> l >> r;
            --l, --r;
            
            // A の区間 [l, r) の最大値を求める処理
            cout << seg.prod(l, r) << endl;
        }
    }
}