難しかったー!!平方分割かな...とまでは思ったので、もっと粘り強く考えられるようにならないと...!
問題概要
長さ の数列 が与えられます。以下の 4 種類のクエリを 回処理してください。
- クエリ 1: が与えられるので、 () を に更新する
- クエリ 2: が与えられるので、 () を に更新する
- クエリ 3: が与えられるので、 () を に更新する
- クエリ 4: が与えられるので、 かつ を満たす の個数を出力する
制約
考えたこと
平方分割系かなとは思った。でも、次のように考えてしまい、先に進めなかった。
各バケットごとにクエリ 1〜3 の操作列を保存しないといけなくて、操作列が長さ だけあるので結局平方分割しても扱いきれない
たとえクエリ 3 とクエリ 4 だけだとしても、クエリ 4 に対応するのはキツい
これらの難所をそれぞれクリアしていく!!!
操作列を単純化する
まず、次のことに着目すると、長さ の操作列も、それと等価な長さ 3 の操作列に変換することができる!!!
- (add x) して (chmin y) する = (chmin y - x) して (add x) する
- (add x) して (chmax y) する = (chmax y - x) して (add x) する
- (chmax x) して (chmin y) する = (chmin y) して (chmax min(x, y)) する
このように順序を入れ替えることができるので、操作列は
- (chmin x) して (chmax y) して (add z) する
という「標準形」に変形することができる。
平方分割して各バケットごとにソート
次に、そもそも区間 に 以上 以下のものが何個あるか、というクエリにどう答えたらよいかを考える。そのためには、仮に配列全体がソートされていればよいことに気づく。仮に配列全体がソートされていれば、二分探索によって 以上 以下の要素の個数がわかる。
僕はそれで、平衡二分探索木とかが必要なのかな...とか考えてしまって先に進めなくなってしまった。しかしよく考えたら、平方分割しておいて、各バケットごとにソートされていればそれでよかった!!!
まとめ
まとめると次のようにすれば OK。
前処理
- まず配列全体をバケットごとに分割する
- 各バケットごとに、生配列 A と、ソート後配列 B を用意しておく。
前処理の計算量は
区間更新クエリ
- 区間更新クエリを、各バケットごとに処理する。このとき左右両端のバケット以外のバケットは、バケット全体にかかる更新であることに注意する
- 左右両端のバケットは、予め遅延評価しておいた「chmin x、chmax y、add z」をバケット全体に実施した上で、新クエリに愚直に答える ()
- それ以外のバケットは、操作列「chmin x、chmax y、add z」を適切に更新する ()
計算量は、)
区間取得クエリ
- 左右両端のバケットは、予め遅延評価しておいた「chmin x、chmax y、add z」をバケット全体に実施した上で、愚直に数える ()
- それ以外のバケットは、「y 以下の個数」 - 「x-1 以下の個数」を計算すればよい。
ここで、「chmin x、chmax y、add z したときに v 以下」という条件は次のように言い換えられる
- chmin x、chmax y、add z したときに v 以下
- chmin x、chmax y したときに v-z 以下
- chmin x したときに v-z 以下 (y <= v-z の場合に限る、それ以外は 0 個)
- x <= v-z ならばすべて、それ以外なら v-z 以下
計算量は 。
以上をまとめると、全体の計算量は となる。
#include <bits/stdc++.h> using namespace std; template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; } template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; } const long long INF = 1LL<<59; struct SqrtDecomp { int SIZE = 200; vector<vector<long long>> array, sorted; vector<long long> vmin, vmax, vadd; SqrtDecomp(const vector<long long> &a) { for (int i = 0; i < a.size(); i += SIZE) { vector<long long> temp; for (int j = i; j < i + SIZE && j < a.size(); ++j) { temp.push_back(a[j]); } array.push_back(temp); sort(temp.begin(), temp.end()); sorted.push_back(temp); vmin.push_back(INF), vmax.push_back(-INF), vadd.push_back(0); } } // relax delay inline void setsort(int id) { if (id >= array.size()) return; sorted[id] = array[id]; sort(sorted[id].begin(), sorted[id].end()); } inline void relax(int id) { if (id >= array.size()) return; for (int i = 0; i < array[id].size(); ++i) { chmin(array[id][i], vmin[id]); chmax(array[id][i], vmax[id]); array[id][i] += vadd[id]; } vmin[id] = INF, vmax[id] = -INF, vadd[id] = 0; } // chmin a、chmax b、add c inline void act(int id, int i, long long a, long long b, long long c) { if (id >= array.size() || i >= array[id].size()) return; chmin(array[id][i], a); chmax(array[id][i], b); array[id][i] += c; } // chmin a、chmax b、add c inline void delay(int id, long long a, long long b, long long c) { if (id >= array.size()) return; chmin(vmin[id], a - vadd[id]); chmin(vmax[id], a - vadd[id]); chmax(vmax[id], b - vadd[id]); vadd[id] += c; } // chmin a、chmax b、add c void update(int left, int right, long long a, long long b, long long c) { int lq = left / SIZE, lr = left % SIZE; int rq = right / SIZE, rr = right % SIZE; if (lq == rq) { relax(lq); for (int i = lr; i < rr; ++i) act(lq, i, a, b, c); setsort(lq); return; } relax(lq), relax(rq); for (int i = lr; i < array[lq].size(); ++i) act(lq, i, a, b, c); for (int i = 0; i < rr; ++i) act(rq, i, a, b, c); setsort(lq), setsort(rq); for (int id = lq + 1; id < rq; ++id) delay(id, a, b, c); } // getter inline long long get(int id, long long v) { v -= vadd[id]; if (vmax[id] > v) return 0; if (vmin[id] <= v) return array[id].size(); return upper_bound(sorted[id].begin(), sorted[id].end(), v) - sorted[id].begin(); } long long get(int left, int right, long long v) { int lq = left / SIZE, lr = left % SIZE; int rq = right / SIZE, rr = right % SIZE; long long res = 0; if (lq == rq) { relax(lq), setsort(lq); for (int i = lr; i < rr; ++i) if (array[lq][i] <= v) ++res; return res; } relax(lq), relax(rq), setsort(lq), setsort(rq); for (int i = lr; i < array[lq].size(); ++i) if (array[lq][i] <= v) ++res; for (int i = 0; i < rr; ++i) if (array[rq][i] <= v) ++res; for (int id = lq + 1; id < rq; ++id) res += get(id, v); return res; } }; int main() { int N, Q; scanf("%d %d", &N, &Q); vector<long long> a(N); for (int i = 0; i < N; ++i) cin >> a[i]; SqrtDecomp sd(a); for (int q = 0; q < Q; ++q) { int type, left, right; long long x, y; scanf("%d %d %d %lld", &type, &left, &right, &x); if (type == 1) sd.update(left-1, right, x, -INF, 0); else if (type == 2) sd.update(left-1, right, INF, x, 0); else if (type == 3) sd.update(left-1, right, INF, -INF, x); else { scanf("%lld", &y); cout << sd.get(left-1, right, y) - sd.get(left-1, right, x-1) << endl; } } }