けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 3170 Freqs (HUPC 2020 day1-G)

難しかったー!!平方分割かな...とまでは思ったので、もっと粘り強く考えられるようにならないと...!

問題へのリンク

問題概要

長さ  N の数列  a_{1}, a_{2}, \dots, a_{N} が与えられます。以下の 4 種類のクエリを  Q 回処理してください。

  • クエリ 1:  l, r, x が与えられるので、 a_{i} ( l \le i \le r) を  \min(a_{i}, x) に更新する
  • クエリ 2:  l, r, x が与えられるので、 a_{i} ( l \le i \le r) を  \max(a_{i}, x) に更新する
  • クエリ 3:  l, r, x が与えられるので、 a_{i} ( l \le i \le r) を  a_{i} + x に更新する
  • クエリ 4:  l, r, x, y が与えられるので、 l \le i \le r かつ  x \le a_{i} \le y を満たす  i の個数を出力する

制約

  •  1 \le N, Q \le 10^{5}

考えたこと

平方分割系かなとは思った。でも、次のように考えてしまい、先に進めなかった。

  1. 各バケットごとにクエリ 1〜3 の操作列を保存しないといけなくて、操作列が長さ  Q だけあるので結局平方分割しても扱いきれない

  2. たとえクエリ 3 とクエリ 4 だけだとしても、クエリ 4 に対応するのはキツい

これらの難所をそれぞれクリアしていく!!!

操作列を単純化する

まず、次のことに着目すると、長さ  Q の操作列も、それと等価な長さ 3 の操作列に変換することができる!!!

  • (add x) して (chmin y) する = (chmin y - x) して (add x) する
  • (add x) して (chmax y) する = (chmax y - x) して (add x) する
  • (chmax x) して (chmin y) する = (chmin y) して (chmax min(x, y)) する

このように順序を入れ替えることができるので、操作列は

  • (chmin x) して (chmax y) して (add z) する

という「標準形」に変形することができる。

平方分割して各バケットごとにソート

次に、そもそも区間  \lbrack l, r) x 以上  y 以下のものが何個あるか、というクエリにどう答えたらよいかを考える。そのためには、仮に配列全体がソートされていればよいことに気づく。仮に配列全体がソートされていれば、二分探索によって  x 以上  y 以下の要素の個数がわかる。

僕はそれで、平衡二分探索木とかが必要なのかな...とか考えてしまって先に進めなくなってしまった。しかしよく考えたら、平方分割しておいて、各バケットごとにソートされていればそれでよかった!!!

まとめ

まとめると次のようにすれば OK。

前処理

  • まず配列全体をバケットごとに分割する
  • 各バケットごとに、生配列 A と、ソート後配列 B を用意しておく。

前処理の計算量は  O(N \log N)

区間更新クエリ

  • 区間更新クエリを、各バケットごとに処理する。このとき左右両端のバケット以外のバケットは、バケット全体にかかる更新であることに注意する
  • 左右両端のバケットは、予め遅延評価しておいた「chmin x、chmax y、add z」をバケット全体に実施した上で、新クエリに愚直に答える ( O(\sqrt{N} \log{N}))
  • それ以外のバケットは、操作列「chmin x、chmax y、add z」を適切に更新する ( O(\sqrt{N}))

計算量は、 O(\sqrt{N} \log{N})

区間取得クエリ

  • 左右両端のバケットは、予め遅延評価しておいた「chmin x、chmax y、add z」をバケット全体に実施した上で、愚直に数える ( O(\sqrt{N}))
  • それ以外のバケットは、「y 以下の個数」 - 「x-1 以下の個数」を計算すればよい。

ここで、「chmin x、chmax y、add z したときに v 以下」という条件は次のように言い換えられる

  • chmin x、chmax y、add z したときに v 以下
  • chmin x、chmax y したときに v-z 以下
  • chmin x したときに v-z 以下 (y <= v-z の場合に限る、それ以外は 0 個)
  • x <= v-z ならばすべて、それ以外なら v-z 以下

計算量は  O(\sqrt{N} \log{N})

以上をまとめると、全体の計算量は  O((N + Q\log N)\sqrt{N}) となる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

const long long INF = 1LL<<59;
struct SqrtDecomp {
    int SIZE = 200;
    vector<vector<long long>> array, sorted;
    vector<long long> vmin, vmax, vadd;
    
    SqrtDecomp(const vector<long long> &a) {
        for (int i = 0; i < a.size(); i += SIZE) {
            vector<long long> temp;
            for (int j = i; j < i + SIZE && j < a.size(); ++j) {
                temp.push_back(a[j]);
            }
            array.push_back(temp);
            sort(temp.begin(), temp.end());
            sorted.push_back(temp);
            vmin.push_back(INF), vmax.push_back(-INF), vadd.push_back(0);
        }
    }

    // relax delay
    inline void setsort(int id) {
        if (id >= array.size()) return;
        sorted[id] = array[id];
        sort(sorted[id].begin(), sorted[id].end());
    }
    inline void relax(int id) {
        if (id >= array.size()) return;
        for (int i = 0; i < array[id].size(); ++i) {
            chmin(array[id][i], vmin[id]);
            chmax(array[id][i], vmax[id]);
            array[id][i] += vadd[id];
        }
        vmin[id] = INF, vmax[id] = -INF, vadd[id] = 0;
    }

    // chmin a、chmax b、add c
    inline void act(int id, int i, long long a, long long b, long long c) {
        if (id >= array.size() || i >= array[id].size()) return;
        chmin(array[id][i], a);
        chmax(array[id][i], b);
        array[id][i] += c;
    }

    // chmin a、chmax b、add c
    inline void delay(int id, long long a, long long b, long long c) {
        if (id >= array.size()) return;
        chmin(vmin[id], a - vadd[id]);
        chmin(vmax[id], a - vadd[id]);
        chmax(vmax[id], b - vadd[id]);
        vadd[id] += c;
    }

    // chmin a、chmax b、add c
    void update(int left, int right, 
                long long a, long long b, long long c) {
        int lq = left / SIZE, lr = left % SIZE;
        int rq = right / SIZE, rr = right % SIZE;
        if (lq == rq) {
            relax(lq);
            for (int i = lr; i < rr; ++i) act(lq, i, a, b, c);
            setsort(lq);
            return;
        }
        relax(lq), relax(rq);
        for (int i = lr; i < array[lq].size(); ++i) act(lq, i, a, b, c);
        for (int i = 0; i < rr; ++i) act(rq, i, a, b, c);
        setsort(lq), setsort(rq);
        for (int id = lq + 1; id < rq; ++id) delay(id, a, b, c);
    }

    // getter
    inline long long get(int id, long long v) {
        v -= vadd[id];
        if (vmax[id] > v) return 0;
        if (vmin[id] <= v) return array[id].size();
        return upper_bound(sorted[id].begin(), sorted[id].end(), v) - sorted[id].begin();
    }
    long long get(int left, int right, long long v) {
        int lq = left / SIZE, lr = left % SIZE;
        int rq = right / SIZE, rr = right % SIZE;
        long long res = 0;
        if (lq == rq) {
            relax(lq), setsort(lq);
            for (int i = lr; i < rr; ++i) if (array[lq][i] <= v) ++res;
            return res;
        }
        relax(lq), relax(rq), setsort(lq), setsort(rq);
        for (int i = lr; i < array[lq].size(); ++i) if (array[lq][i] <= v) ++res;
        for (int i = 0; i < rr; ++i) if (array[rq][i] <= v) ++res;
        for (int id = lq + 1; id < rq; ++id) res += get(id, v);
        return res;
    }
};

int main() {
    int N, Q;
    scanf("%d %d", &N, &Q);
    vector<long long> a(N);
    for (int i = 0; i < N; ++i) cin >> a[i];
    SqrtDecomp sd(a);
    for (int q = 0; q < Q; ++q) {
        int type, left, right;
        long long x, y;
        scanf("%d %d %d %lld", &type, &left, &right, &x);
        if (type == 1) sd.update(left-1, right, x, -INF, 0);
        else if (type == 2) sd.update(left-1, right, INF, x, 0);
        else if (type == 3) sd.update(left-1, right, INF, -INF, x);
        else {
            scanf("%lld", &y);
            cout << sd.get(left-1, right, y) - sd.get(left-1, right, x-1) << endl;
        }
    }
}