けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

全国統一プログラミング王決定戦 本選 D - Deforestation (500 点)

遅延評価セグメントツリーで殴った...速度重視ならそれでいい気もする

問題へのリンク

問題概要

 N 本の竹があって、時刻 0 においてすべての竹の高さは 0 である。それぞれの竹は時刻が 1 経過するごとに高さが 1 増える。

竹を伐採するイベントが  M 回予定されていて、 i 番目のイベントは時刻  T_i に行われ、そのイベントでは番号が  L_i 以上  R_i 以下の竹がすべて伐採され、竹の長さが 0 になる。このとき、その時点での竹の高さの合計値がポイントとして加算される。また、竹は伐採された後も伸び続ける。

最終的に得られるポイントが何点になるか求めよ。

制約

  •  1 \le N, M \le 2 × 10^{5}

考えたこと

結局以下のような問題に等しい。


  •  N 要素の配列があって初期値はすべて 0 である。
  •  i 番目のクエリでは区間 [  L_i, R_i ] について値を  T_i と比べて大きい方に更新する
  • 最終状態において配列の各値の総和を求めて出力する

解法 1: 遅延評価セグメントツリーで殴る (本番これでやった)

以下のことができるセグ木があれば OK!

  • 更新 (l, r, v): 区間 [l, r) の値を v と比べて大きい方に更新する
  • 取得 i: i の値を取得する

通常のセグメント木は、更新は 1 点に対してのみで、取得が区間クエリであることが多い。これに対して更新も区間クエリに対応する技術として、遅延評価と呼ばれるものがあって、つたじぇー⭐️さんの以下の記事に詳細があります!

これで殴る。なお、遅延評価セグメントツリーは作用つきモノイドを扱うものとみなすことができて、一般に「モノイド同士の演算」「モノイドへの作用」「作用の合成」を表す演算を定義することで定まります。今回の場合は

- モノイド同士の演算: auto fm = [](long long a, long long b) { return max(a, b); };
- モノイドへの作用: auto fa = [](long long &a, long long d) { a = max(a, d); };
- 作用の合成: auto fl = [](long long &d, long long e) { d = max(d, e); };

としています (モノイド同士の演算は、今回は割と何でもいいです)。

#include <iostream>
#include <vector>
#include <functional>
using namespace std;

// Segment Tree
template<class Monoid, class Action> struct SegTree {
    using FuncMonoid = function< Monoid(Monoid, Monoid) >;
    using FuncAction = function< void(Monoid&, Action) >;
    using FuncLazy = function< void(Action&, Action) >;
    FuncMonoid FM;
    FuncAction FA;
    FuncLazy FL;
    Monoid UNITY_MONOID;
    Action UNITY_LAZY;
    int SIZE, HEIGHT;
    vector<Monoid> dat;
    vector<Action> lazy;
    
    SegTree() { }
    SegTree(int n, const FuncMonoid fm, const FuncAction fa, const FuncLazy fl,
            const Monoid &unity_monoid, const Action &unity_lazy)
    : FM(fm), FA(fa), FL(fl), UNITY_MONOID(unity_monoid), UNITY_LAZY(unity_lazy) {
        SIZE = 1; HEIGHT = 0;
        while (SIZE < n) SIZE <<= 1, ++HEIGHT;
        dat.assign(SIZE * 2, UNITY_MONOID);
        lazy.assign(SIZE * 2, UNITY_LAZY);
    }
    void init(int n, const FuncMonoid fm, const FuncAction fa, const FuncLazy fl,
              const Monoid &unity_monoid, const Action &unity_lazy) {
        FM = fm; FA = fa; FL = fl;
        UNITY_MONOID = unity_monoid; UNITY_LAZY = unity_lazy;
        SIZE = 1; HEIGHT = 0;
        while (SIZE < n) SIZE <<= 1, ++HEIGHT;
        dat.assign(SIZE * 2, UNITY_MONOID);
        lazy.assign(SIZE * 2, UNITY_LAZY);
    }
    
    /* set, a is 0-indexed */
    void set(int a, const Monoid &v) { dat[a + SIZE] = v; }
    void build() {
        for (int k = SIZE - 1; k > 0; --k)
            dat[k] = FM(dat[k*2], dat[k*2+1]);
    }
    
    /* update [a, b) */
    inline void evaluate(int k) {
        if (lazy[k] == UNITY_LAZY) return;
        if (k < SIZE) FL(lazy[k*2], lazy[k]), FL(lazy[k*2+1], lazy[k]);
        FA(dat[k], lazy[k]);
        lazy[k] = UNITY_LAZY;
    }
    inline void update(int a, int b, const Action &v, int k, int l, int r) {
        evaluate(k);
        if (a <= l && r <= b)  FL(lazy[k], v), evaluate(k);
        else if (a < r && l < b) {
            update(a, b, v, k*2, l, (l+r)>>1), update(a, b, v, k*2+1, (l+r)>>1, r);
            dat[k] = FM(dat[k*2], dat[k*2+1]);
        }
    }
    inline void update(int a, int b, const Action &v) { update(a, b, v, 1, 0, SIZE); }
    
    /* get [a, b) */
    inline Monoid get(int a, int b, int k, int l, int r) {
        evaluate(k);
        if (a <= l && r <= b)
            return dat[k];
        else if (a < r && l < b)
            return FM(get(a, b, k*2, l, (l+r)>>1), get(a, b, k*2+1, (l+r)>>1, r));
        else
            return UNITY_MONOID;
    }
    inline Monoid get(int a, int b) { return get(a, b, 1, 0, SIZE); }
    inline Monoid operator [] (int a) { return get(a, a+1); }
    
    /* debug */
    void print() {
        for (int i = 0; i < SIZE; ++i) { cout << (*this)[i]; if (i != SIZE) cout << ","; }
        cout << endl;
    }
};

int main() {
    int N, M;
    while (cin >> N >> M) {
        auto fm = [](long long a, long long b) { return max(a, b); };
        auto fa = [](long long &a, long long d) { a = max(a, d); };
        auto fl = [](long long &d, long long e) { d = max(d, e); };
        SegTree<long long, long long> seg(N+1, fm, fa, fl, 0, 0);
        
        for (int q = 0; q < M; ++q) {
            long long t;
            int a, b;
            cin >> t >> a >> b;
            --a;
            seg.update(a, b, t);
        }
        long long res  = 0;
        for (int i = 0; i < N; ++i) {
            res += seg.get(i, i+1);
        }
        cout << res << endl;
            
    }
}

解法 2: 操作列を後ろから見る

今回のような問題では、「一連の操作は、各ステップにおいて過去の履歴を破壊してしまう」という性質を持っている。区間 [l, r) を値 v で書き換えてしまうとき、もはや過去がどうだったからは関係なくなってしまう。

そのような操作列を扱うときにはしばしば「操作列を後ろから逆順に見る」という視点が功を奏すイメージがある。

後ろから見ると

  • 区間 [l, r) について時刻 t に伐採したとき、その区間のスコアは t で確定する (その後変更は行われない)

という感じになっている。よって

  • 整数全体を予め set で持っておく
  • 区間クエリ [l, r), t が来るたびに、set 中の [l, r) に含まれる要素を消しながら、答えに t を加算していく

という風にすれば OK。結局各要素はたかだか 1 回ずつ消されることになるので計算量は  O(N\log{N} + M) になる。

#include <iostream>
#include <vector>
#include <set>
using namespace std;

int main() {
    int N, M; cin >> N >> M;
    vector<int> T(M), L(M), R(M);
    for (int i = 0; i < M; ++i) cin >> T[i] >> L[i] >> R[i], --L[i];
    
    set<int> se;
    for (int i = 0; i < N; ++i) se.insert(i);
    
    long long res = 0;
    for (int i = M-1; i >= 0; --i) { // 逆順に
        auto it = se.lower_bound(L[i]); // set の中から L[i] 以上となるポインタを見つける
        
        // set 内を走査していく、値が R[i] 以上になったらその時点で打ち切る
        while (it != se.end() && *it < R[i]) {
            res += T[i];
            auto it2 = it; // set を破壊する操作が怖いので避けておいて
            ++it2; // 予め次の要素へとインクリメントしておく
            se.erase(it); // it を削除する
            it = it2; // 次の要素へ
        }
    }
    cout << res << endl;
}

解法 3: いもす法的に解く

この問題、操作が「v との max をとる」のところを「v を加算する」にすると imos 法っぽくなることに気づく。

そこで今回も各区間クエリに対して、その両端の情報を貯めておいて、最後にそれを総合するいもす法的方法がとれる。各情報につき、区間の L 側で insert し、R 側で erase する感じの実装をする。

何をするのかについては実装例を見るとすぐにわかりそう。

(アルメリアさんの記事もその方法でやってそう)

#include <iostream>
#include <vector>
#include <set>
using namespace std;

int main() {
    // ins[i] := i を区間の始点とするような時刻 t の集まり. era[i] も同様
    int N, M; cin >> N >> M;
    vector<vector<int> > ins(N+1), era(N+1);
    for (int i = 0; i < M; ++i) {
        int T, L, R; cin >> T >> L >> R; --L;
        ins[L].push_back(T);
        era[R].push_back(T);
    }

    // いもす法
    set<int> se; // その点その点についての時刻イベントを平面走査していく
    se.insert(0); // 番兵
    long long res = 0;
    for (int i = 0; i <= N; ++i) {
        for (auto t : ins[i]) se.insert(t);
        for (auto t : era[i]) se.erase(t);

        // この時点で se が、時刻 i が「どの時刻に伐採されたか」を表す集合になっている
        res += *prev(se.end());
    }
    cout << res << endl;
}