けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 153 F - Silver Fox vs Monster (1D, 水色, 600 点)

区間加算に対応したデータ構造の出番!

問題へのリンク

問題概要

 N 体のモンスターがいて、それぞれ座標  x_{i} にいて、HP は  H_{i} である。すべてのモンスターを倒したい。

1 回の魔法で、座標  x を指定して、[  x-D, x+D ] の範囲内にいるモンスターの HP をすべて  A ずつ減少することができる。モンスターは HP が 0 以下の状態を倒したとみなす。

すべてのモンスターの HP を 0 以下にするのに要する魔法回数の最小値を求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

まずこの手の区間爆発系の問題で考えることは、

  • モンスターのいるギリギリの範囲で爆発させるもののみを考えれば良い

というのがメチャメチャよくある。幾何だと特に!!!先々週の最小包含円なんかも、結局「3 点を通る円」や「2 点を直径とする円」のみを考えれば良い、という論法だった。

具体的には、下図のように、左端にモンスターがいるような場所で爆発させるもののみを考えれば OK。

そうすると、解法はとてもシンプルで

  • まず一番左にいるモンスターを左端として、その HP が 0 以下になるまで爆発させる
  • 次に残っているモンスターのうち、一番左にいるモンスターを左端として、その HP が 0 以下になるまで爆発させる
  • ...

というのを、全モンスターの HP が 0 以下になるまで繰り返せば OK。しかしこのままでは、一度の爆発で同時に  O(N) 体のモンスターの HP を減らす必要がある可能性があることから、愚直にやると  O(N^{2}) の計算時間を必要としてしまう。そこで様々な高速化方法が考えられる

  1. BIT や遅延セグ木で殴る
  2. imos しながら
  3. スライド最小値的にやる (queue や deque を使う)

なお、1 の BIT や imos では「爆発区間の右端を求める」のに二分探索をしているが、そこもそれぞれ尺取り法を使えば少し高速化できる。

高速化 (1): BIT や遅延セグ木で殴る

そこで、以下のことが高速にできるデータ構造があると嬉しい:

  • 区間 [l, r) に値 a を加算する
  • 位置 l の値を取得する

これができるデータ構造にはいろいろあると思う。BIT を少し工夫すると


  • 区間 [l, r) に値 a を加算する
  • 区間 [l, r) の合計値を取得する

というのを、ともに  O(\log{N}) で処理することができる。必要な操作の完全上位互換だけど、これを持っていれば貼るだけで OK。計算量は  O(N\log{N}) となる。

https://github.com/drken1215/algorithm/blob/master/DataStructure/binary_indexed_tree_RAQ.cppgithub.com

解説は蟻本の P.163 以降にある!!!

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// 区間加算にも対応した BIT
template <class Abel> struct BIT {
    vector<Abel> dat[2];
    Abel UNITY_SUM = 0;                     // to be set
    
    /* [1, n] */
    BIT(int n) { init(n); }
    void init(int n) { for (int iter = 0; iter < 2; ++iter) dat[iter].assign(n + 1, UNITY_SUM); }
    
    /* a, b are 1-indexed, [a, b) */
    inline void sub_add(int p, int a, Abel x) {
        for (int i = a; i < (int)dat[p].size(); i += i & -i)
            dat[p][i] = dat[p][i] + x;
    }
    inline void add(int a, int b, Abel x) {
        sub_add(0, a, x * -(a - 1)); sub_add(1, a, x); sub_add(0, b, x * (b - 1)); sub_add(1, b, x * (-1));
    }
    
    /* a is 1-indexed, [a, b) */
    inline Abel sub_sum(int p, int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i) res = res + dat[p][i];
        return res;
    }
    inline Abel sum(int a, int b) {
        return sub_sum(0, b - 1) + sub_sum(1, b - 1) * (b - 1) - sub_sum(0, a - 1) - sub_sum(1, a - 1) * (a - 1);
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat[0].size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};

using pll = pair<long long, long long>;
int N;
long long D, A;
vector<long long> X, H;

long long solve() {
    // モンスターを X が小さい順に
    vector<int> ids(N);
    for (int i = 0; i < N; ++i) ids[i] = i;
    sort(ids.begin(), ids.end(), [&](int i, int j) {
            return X[i] < X[j]; });
    vector<long long> nX(N), nH(N);
    for (int i = 0; i < N; ++i) nX[i] = X[ids[i]], nH[i] = H[ids[i]];
    X = nX, H = nH;

    // BIT で処理していく
    BIT<long long> bit(N+10);
    for (int i = 0; i < N; ++i) bit.add(i+1, i+2, H[i]); // 初期化
    long long res = 0;
    for (int i = 0; i < N; ++i) {
        long long cur = bit.sum(i+1, i+2);
        if (cur <= 0) continue;

        // モンスター i を倒すのに必要な回数
        long long need = (cur + A - 1) / A;

        // X[i] を左端とした爆発が届く範囲を求める
        long long right = X[i] + D * 2;
        int id = upper_bound(X.begin(), X.end(), right) - X.begin();

        // 爆発させる
        bit.add(i+1, id+1, -need * A);
        res += need;
    }
    return res;
}

int main() {
    cin >> N >> D >> A;
    X.resize(N); H.resize(N);
    for (int i = 0; i < N; ++i) cin >> X[i] >> H[i];
    cout << solve() << endl;
}

高速化 (2): imos しながら処理する

区間に値を加算するといえば、imos 法!!!!!
ただ普通 imos 法といえば「区間加算処理が全部終わってから最後に累積和をとって結果を見る」というイメージが強い。今回は

  • 値を取得する
  • 区間に加算する

というのが交互に現れるので、一見すると imos 法でできないように思えてしまう。でも imos できるのだ。毎回のステップで、「途中まで累積和をとる」みたいな感じにすれば OK。

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
using pll = pair<long long, long long>;

int N;
long long D, A;
vector<long long> X, H;

long long solve() {
    // モンスターを X が小さい順に
    vector<int> ids(N);
    for (int i = 0; i < N; ++i) ids[i] = i;
    sort(ids.begin(), ids.end(), [&](int i, int j) {
            return X[i] < X[j]; });
    vector<long long> nX(N), nH(N);
    for (int i = 0; i < N; ++i) nX[i] = X[ids[i]], nH[i] = H[ids[i]];
    X = nX, H = nH;

    // imos 法で処理していく
    vector<long long> S(N+1, 0);
    long long res = 0;
    for (int i = 0; i < N; ++i) {
        if (S[i] < H[i]) {
            // モンスター i を倒すのに必要な回数
            long long need = (H[i] - S[i] + A - 1) / A;
        
            // X[i] を左端とした爆発が届く範囲を求める
            long long right = X[i] + D * 2;
            int id = upper_bound(X.begin(), X.end(), right) - X.begin();
            
            // imos しながら爆発させる
            S[i] += need * A;
            S[id] -= need * A;
            res += need;
        }

        // imos の累積和をとる操作
        S[i+1] += S[i];
    }
    return res;
}

int main() {
    cin >> N >> D >> A;
    X.resize(N); H.resize(N);
    for (int i = 0; i < N; ++i) cin >> X[i] >> H[i];
    cout << solve() << endl;
}

高速化 (3): スライド最小値的にやる (queue や deque を使う)

区間に関する処理を順次行っていくのに使える手法として、とくに

  • 区間の左端も右端も単調に増加していく

という場合に、スライド最小値的なアプローチがとれることがある。今回は、キューに (どの x 座標まで爆発させたか、爆発回数) という値を管理して、それとは別に爆発回数 num という変数を管理しておいて、時刻順に処理していく。

  • 毎回のモンスターに対して、キューの先頭の示す座標がモンスターより左側にある限りは、キューの先頭を pop して、そこで爆発が無効になるので num を que.front().second だけ減らす
  • 見ているモンスターが生き残っているときは、新たにダメージを加えて、ダメージイベントを新たにキューに push する

という感じ。コードを見るのが早そう。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

using pll = pair<long long, long long>;
int N;
long long D, A;
vector<long long> X, H;

long long solve() {
    vector<int> ids(N);
    iota(ids.begin(), ids.end(), 0);
    sort(ids.begin(), ids.end(), [&](int i, int j) {
            return X[i] < X[j];});

    long long res = 0;
    queue<pll> que;
    long long num = 0;
    for (auto i : ids) {
        while (!que.empty() && que.front().first < X[i]) {
            num -= que.front().second;
            que.pop();
        }
        H[i] -= num;
        if (H[i] <= 0) continue;
        res += H[i];
        que.push({X[i] + D*2, H[i]});
        num += H[i];
    }
    return res;
}

int main() {
    while (cin >> N >> D >> A) {
        X.resize(N); H.resize(N);
        for (int i = 0; i < N; ++i) {
            cin >> X[i] >> H[i];
            H[i] = (H[i] + A - 1) / A;
        }
        cout << solve() << endl;
    }
}