よくあるデータ構造問題!! めっちゃ色んな解法がある!
問題概要
長さ の整数列
と整数
が与えられる (0-indexed で表している)。
各 に対して、次の問題に答えてください。
個の整数
を小さい順に並び替えたときの先頭
個の総和を求めよ。
制約
考えたこと
要は次のクエリを高速に処理できるデータ構造があればよい。
- 値
を挿入する
- 値
を削除する (複数個ある場合は 1 個だけ削除する)
- 小さい方から数えて
番目の値を求める
このデータ構造を用いて、「値 を push する」「値
を pop する」という処理 (後述) をそれぞれ実装しておくことで、今回の問題は次のように解ける。
- 最初に
を push する
- 各
に対して
- データ構造中の小さい順に
個の総和を出力する
- 値
を push する
- 値
を pop する
- データ構造中の小さい順に
具体的には、以下のように実現できる。なお、変数 sum
を、データ構造中の小さい順に 個の総和 (
個未満の場合はすべての総和) とする。
値
を push
データ構造に値 を挿入して、
sum += A[i]
とする。
そしてこの操作によって「小さい順に 個以内」から追い出される数
v
があれば、sum -= v
とする。
値
を pop
データ構造から値 を削除する。この値が「小さい順に
個以内」に含まれるならば、
sum -= A[i]
とする。
さらにこの操作によって「小さい順に 個以内」に新たに加わる数
v
があれば、sum += v
とする。
データ構造をどうするか
以上の push / pop においては、冒頭で述べた通り、次の 3 種類のクエリを高速に実行できる必要がある。
- 値
を挿入する
- 値
を削除する (複数個ある場合は 1 個だけ削除する)
- 小さい方から数えて
番目の値を求める
そのようなデータ構造は無数に考えられる。
- multiset
- 削除機能を備えた priority_queue
- BIT 上二分探索の機能を備えた BIT
- セグメント木
- Binary Trie
- Wavelet Matrix
解法 (1):multiset (C++)
最も高等知識を必要としない解法。
番目の値を取得するためには、有名な方法がある。それは
- 小さい順に
個を管理するための multiset:
left
- それ以上の値を管理するための multiset:
right
とを 2 つ持つ戦法だ。値を挿入・削除するときには、うまく left
と right
とのバランスを取るように実装すればよい。なお、multiset 中の最小の要素は
*S.begin()
、最大の要素は *S.rbegin()
で取得できる。
注意点として、C++ の multiset には有名な罠がある。multiset から値
を削除したいときに、
S.erase(x);
と書いてしまうと、 に含まれるすべての
が削除されてしまうのだ。今回は
を 1 個だけ削除したいので注意が必要だ。
正しくは、削除したいイテレータ (ここでは it
とする) を求めて S.erase(it)
のようにする必要がある。具体的には、関数 find()
を用いて、
S.erase(S.find(x));
のように書けばよい!
全体として計算量は となる。
コード (解法 (1))
AC コードを開く
#include <bits/stdc++.h>
using namespace std;
int main() {
// 入力
long long N, M, K;
cin >> N >> M >> K;
vector<long long> A(N);
for (int i = 0; i < N; ++i) cin >> A[i];
// 小さい順に K 個の総和
long long sum = 0;
// multiset
multiset<long long> left, right;
auto push = [&](long long x) -> void {
// とりあえず x を left に挿入する
left.insert(x);
sum += x;
// left のサイズが K を超えるなら left の最大値を right に移す
if (left.size() > K) {
long long y = *left.rbegin();
right.insert(y);
sum -= y;
left.erase(left.find(y));
}
};
auto pop = [&](long long x) -> void {
// とりあえず x を削除する
if (left.count(x)) {
left.erase(left.find(x));
sum -= x;
} else {
right.erase(right.find(x));
}
// left のサイズが K 未満になるなら right の最小値を left に移す
if (left.size() < K) {
long long y = *right.begin();
left.insert(y);
sum += y;
right.erase(right.find(y));
}
};
for (int i = 0; i < M; ++i) push(A[i]);
for (int i = 0; i < N - M + 1; ++i) {
cout << sum << " ";
if (i+M < N) push(A[i+M]), pop(A[i]);
}
cout << endl;
}
解法 (2):削除機能を備えた priority_queue
multiset と同様に、priority_queue を用いて 番目の値を管理するためには、
- 小さい順に
個を管理するための priority_queue (最大の値を取得できるようにする):
left
- それ以上の値を管理するための priority_queue (最小の値を取得できるようにする):
right
を用意する方法が有名だ。
しかし、priority_queue で実現するにあたって一つ障壁がある。それは「priority_queue から値 を削除したいときどうするのか」という部分だ。
通常の priority_queue que
では「最大の値」「最小の値」しか削除できない。この問題を乗り越えるために、削除したい値を予約しておく priority_queue del
を別に用意しておく。そして、que
から「最大の値」を取得するときに、次の疑似コードのように「取得した値が del
で予約された値と一致するかどうかを確認し、一致するならば棄却する」ようにすればよい。
// 削除機能を持たせた priority_queue priority_queue<int> que, del; // 挿入 void push(int x) { que.push(x); } // 削除 void pop(int x) { del.push(x); } // 最大の値を取得 (ここでは最大の値の削除はしない) int get() { while (!del.empty() && que.top() == del.top()) { que.pop(); del.pop(); } return que.top(); }
この方法を活用して、priority_queue を合計 4 個用意することで、この問題は で解ける。
コード (解法 (2))
AC コードを開く
#include <bits/stdc++.h>
using namespace std;
int main() {
// 入力
long long N, M, K;
cin >> N >> M >> K;
vector<long long> A(N);
for (int i = 0; i < N; ++i) cin >> A[i];
// 小さい順に K 個の総和
long long sum = 0;
// proiroty_queue
priority_queue<long long> left, dleft;
priority_queue<long long, vector<long long>, greater<long long>> right, dright;
// 境界値を取得する
auto get_left = [&]() -> long long {
while (!dleft.empty() && left.top() == dleft.top()) {
left.pop(); dleft.pop();
}
return left.top();
};
auto get_right = [&]() -> long long {
while (!dright.empty() && right.top() == dright.top()) {
right.pop(); dright.pop();
}
return right.top();
};
// push と pop
auto push = [&](long long x) -> void {
// とりあえず x を left に挿入する
left.push(x);
sum += x;
// left のサイズが K を超えるなら left の最大値を right に移す
if (left.size() - dleft.size() > K) {
long long y = get_left();
sum -= y;
left.pop();
right.push(y);
}
};
auto pop = [&](long long x) -> void {
// とりあえず x を削除する
if (x <= get_left()) {
dleft.push(x);
sum -= x;
} else {
dright.push(x);
}
// left のサイズが K 未満になるなら right の最小値を left に移す
if (left.size() - dleft.size() < K) {
long long y = get_right();
sum += y;
right.pop();
left.push(y);
}
};
for (int i = 0; i < M; ++i) push(A[i]);
for (int i = 0; i < N - M + 1; ++i) {
cout << sum << " ";
if (i+M < N) push(A[i+M]), pop(A[i]);
}
cout << endl;
}
解法 (3):BIT 上二分探索の機能を備えた BIT
ここから先は高級なデータ構造を使う!
「挿入」「削除」「 番目を取得」クエリを処理するのに BIT (Binary Indexed Tree) を使うとよいケースも多い。ただし、挿入する値は
以上
(程度) 以下の整数でなければならない。BIT 内部で用意する配列
dat
に対して、値 x
を挿入するときは dat[x]
にアクセスするためだ。
そこで、今回の問題では座標圧縮をする。配列 の各値を小さい順に番号付ける。それにより、
の各値は
を並び替えたものになる (今回はタイは潰さない)。これで BIT が使えるようになる。
なお、BIT を用いて具体的に「挿入」「削除」「 番目を取得」を実行する方法については、たとえば次の資料などを参考に。
番目の値を取得する部分については、「BIT 上の二分探索」がキーワード。
コード (解法 (3))
小さい順に 個の総和を求める部分では、これまでのように変数
sum
を用いてもよいが、下コードのように、総和を求めるための専用の BIT 変数をもう 1 つ用意する方が簡便だと思う!
AC コードを開く
#include <bits/stdc++.h>
using namespace std;
// get(k): binary search on BIT
template <class Abel> struct BIT {
Abel UNITY_SUM = 0;
vector<Abel> dat;
// [0, n)
BIT(int n, Abel unity = 0) : UNITY_SUM(unity), dat(n, unity) { }
void init(int n) {
dat.assign(n, UNITY_SUM);
}
// a is 0-indexed
inline void add(int a, Abel x) {
for (int i = a; i < (int)dat.size(); i |= i + 1)
dat[i] = dat[i] + x;
}
// [0, a), a is 0-indexed
inline Abel sum(int a) {
Abel res = UNITY_SUM;
for (int i = a - 1; i >= 0; i = (i & (i + 1)) - 1)
res = res + dat[i];
return res;
}
// [a, b), a and b are 0-indexed
inline Abel sum(int a, int b) {
return sum(b) - sum(a);
}
// k-th number (k is 0-indexed)
int get(long long k) {
++k;
int res = 0;
int N = 1;
while (N < (int)dat.size()) N *= 2;
for (int i = N / 2; i > 0; i /= 2) {
if (res + i - 1 < (int)dat.size() && dat[res + i - 1] < k) {
k = k - dat[res + i - 1];
res = res + i;
}
}
return res;
}
// debug
void print() {
for (int i = 0; i < (int)dat.size(); ++i)
cout << sum(i, i + 1) << ",";
cout << endl;
}
};
int main() {
// 入力
long long N, M, K;
cin >> N >> M >> K;
vector<long long> A(N);
for (int i = 0; i < N; ++i) cin >> A[i];
// 座標圧縮を準備
vector<pair<long long,int>> comp(N);
for (int i = 0; i < N; ++i) comp[i] = {A[i], i};
sort(comp.begin(), comp.end());
vector<int> order(N);
for (int i = 0; i < N; ++i) order[comp[i].second] = i;
// BIT
BIT<long long> bit(N + 1), sum(N + 1);
// push と pop
auto push = [&](long long x, int id) -> void {
bit.add(id, 1);
sum.add(id, x);
};
auto pop = [&](long long x, int id) -> void {
bit.add(id, -1);
sum.add(id, -x);
};
auto get = [&]() -> long long {
int kthid = bit.get(K-1);
return sum.sum(kthid+1);
};
for (int i = 0; i < M; ++i) push(A[i], order[i]);
for (int i = 0; i < N - M + 1; ++i) {
cout << get() << " ";
if (i+M < N) {
push(A[i+M], order[i+M]);
pop(A[i], order[i]);
}
}
cout << endl;
}
解法 (4):セグメント木
BIT で解けるならセグメント木でも解ける。解法 (3) で準備した 2 つの BIT を相当して、
- 各ノードの値:(挿入された値の個数, 挿入された値の総和)
- ノード間の演算:(挿入された値の個数の和, 挿入された値の総和の和)
によって定義されたセグメント木を用意する。
「小さい順に 個の総和」を求めるクエリを処理するためには次のようにする。
コード (解法 (4))
ACL を用いた。ACL には、セグメント木上の二分探索を実行する関数 max_right()
も用意されている。
AC コードを開く
#include <bits/stdc++.h>
#include <atcoder/segtree>
using namespace std;
using namespace atcoder;
// セグメント木の設定
using Monoid = pair<long long, long long>;
Monoid op(Monoid a, Monoid b) {
return {a.first + b.first, a.second + b.second};
};
Monoid e(){return {0, 0};}
int main() {
// 入力
long long N, M, K;
cin >> N >> M >> K;
vector<long long> A(N);
for (int i = 0; i < N; ++i) cin >> A[i];
// 座標圧縮を準備
vector<pair<long long,int>> comp(N);
for (int i = 0; i < N; ++i) comp[i] = {A[i], i};
sort(comp.begin(), comp.end());
vector<int> order(N);
for (int i = 0; i < N; ++i) order[comp[i].second] = i;
// セグメント木
segtree<Monoid, op, e> seg(N);
// セグメント木上の二分探索を実行するための関数
auto f = [&](Monoid x) -> bool { return x.first <= K; };
// push と pop
auto push = [&](long long x, int id) -> void {
seg.set(id, Monoid(1, x));
};
auto pop = [&](long long x, int id) -> void {
seg.set(id, e());
};
auto get = [&]() -> long long {
int r = seg.max_right(0, f);
return seg.prod(0, r).second;
};
for (int i = 0; i < M; ++i) push(A[i], order[i]);
for (int i = 0; i < N - M + 1; ++i) {
cout << get() << " ";
if (i+M < N) {
push(A[i+M], order[i+M]);
pop(A[i], order[i]);
}
}
cout << endl;
}
解法 (5):Binary Trie
「挿入」「削除」「 番目の値を取得」ができるデータ構造はまだまだある。
挿入する値が非負整数値ならば、Binary Trie も有力。詳細はここでは省略。
コード (解法 (5))
AC コードを開く
#include <bits/stdc++.h>
using namespace std;
// Binary Trie
template<typename INT, size_t MAX_DIGIT> struct BinaryTrie {
struct Node {
size_t count;
Node *prev, *left, *right;
Node(Node *prev) : count(0), prev(prev), left(nullptr), right(nullptr) {}
};
INT lazy;
Node *root;
// constructor
BinaryTrie() : lazy(0), root(emplace(nullptr)) {}
inline size_t get_count(Node *v) const { return v ? v->count : 0; }
inline size_t size() const { return get_count(root); }
// add and get value of Node
inline void add(INT val) {
lazy ^= val;
}
inline INT get(Node *v) {
if (!v) return -1;
INT res = 0;
for (int i = 0; i < MAX_DIGIT; ++i) {
if (v == v->prev->right)
res |= INT(1)<<i;
v = v->prev;
}
return res ^ lazy;
}
// find Node* whose value is val
Node* find(INT val) {
INT nval = val ^ lazy;
Node *v = root;
for (int i = MAX_DIGIT-1; i >= 0; --i) {
bool flag = (nval >> i) & 1;
if (flag) v = v->right;
else v = v->left;
if (!v) return v;
}
return v;
}
// insert
inline Node* emplace(Node *prev) {
return new Node(prev);
}
void insert(INT val, size_t k = 1) {
INT nval = val ^ lazy;
Node *v = root;
for (int i = MAX_DIGIT-1; i >= 0; --i) {
bool flag = (nval >> i) & 1;
if (flag && !v->right) v->right = emplace(v);
if (!flag && !v->left) v->left = emplace(v);
if (flag) v = v->right;
else v = v->left;
}
v->count += k;
while ((v = v->prev)) v->count = get_count(v->left) + get_count(v->right);
}
// erase
Node* clear(Node *v) {
if (!v || get_count(v)) return v;
delete(v);
return nullptr;
}
bool erase(Node *v, size_t k = 1) {
if (!v) return false;
v->count -= k;
while ((v = v->prev)) {
v->left = clear(v->left);
v->right = clear(v->right);
v->count = get_count(v->left) + get_count(v->right);
}
return true;
}
bool erase(INT val) {
auto v = find(val);
return erase(v);
}
// max (with xor-addition of val) and min (with xor-addition of val)
Node* get_max(INT val = 0) {
INT nval = val ^ lazy;
Node* v = root;
for (int i = MAX_DIGIT-1; i >= 0; --i) {
bool flag = (nval >> i) & 1;
if (!v->right) v = v->left;
else if (!v->left) v = v->right;
else if (flag) v = v->left;
else v = v->right;
}
return v;
}
Node* get_min(INT val = 0) {
return get_max(~val & ((INT(1)<<MAX_DIGIT)-1));
}
// lower_bound, upper_bound
Node* get_cur_node(Node *v, int i) {
if (!v) return v;
Node *left = v->left, *right = v->right;
if ((lazy >> i) & 1) swap(left, right);
if (left) return get_cur_node(left, i+1);
else if (right) return get_cur_node(right, i+1);
return v;
}
Node* get_next_node(Node *v, int i) {
if (!v->prev) return nullptr;
Node *left = v->prev->left, *right = v->prev->right;
if ((lazy >> (i+1)) & 1) swap(left, right);
if (v == left && right) return get_cur_node(right, i);
else return get_next_node(v->prev, i+1);
}
Node* lower_bound(INT val) {
INT nval = val ^ lazy;
Node *v = root;
for (int i = MAX_DIGIT-1; i >= 0; --i) {
bool flag = (nval >> i) & 1;
if (flag && v->right) v = v->right;
else if (!flag && v->left) v = v->left;
else if ((val >> i) & 1) return get_next_node(v, i);
else return get_cur_node(v, i);
}
return v;
}
Node* upper_bound(INT val) {
return lower_bound(val + 1);
}
size_t order_of_val(INT val) {
Node *v = root;
size_t res = 0;
for (int i = MAX_DIGIT-1; i >= 0; --i) {
Node *left = v->left, *right = v->right;
if ((lazy >> i) & 1) swap(left, right);
bool flag = (val >> i) & 1;
if (flag) {
res += get_count(left);
v = right;
}
else v = left;
}
return res;
}
// k-th, k is 0-indexed
Node* get_kth(size_t k, INT val = 0) {
Node *v = root;
if (get_count(v) <= k) return nullptr;
for (int i = MAX_DIGIT-1; i >= 0; --i) {
bool flag = (lazy >> i) & 1;
Node *left = (flag ? v->right : v->left);
Node *right = (flag ? v->left : v->right);
if (get_count(left) <= k) k -= get_count(left), v = right;
else v = left;
}
return v;
}
// debug
void print(Node *v, string prefix = "") {
if (!v) return;
cout << prefix << ": " << v->count << endl;
print(v->left, prefix + "0");
print(v->right, prefix + "1");
}
void print() {
print(root);
}
vector<INT> eval(Node *v, int digit) const {
vector<INT> res;
if (!v) return res;
if (!v->left && !v->right) {
for (int i = 0; i < get_count(v); ++i) res.push_back(0);
return res;
}
const auto& left = eval(v->left, digit-1);
const auto& right = eval(v->right, digit-1);
for (auto val : left) res.push_back(val);
for (auto val : right) res.push_back(val + (INT(1)<<digit));
return res;
}
vector<INT> eval() const {
auto res = eval(root, MAX_DIGIT-1);
for (auto &val : res) val ^= lazy;
return res;
}
friend ostream& operator << (ostream &os,
const BinaryTrie<INT, MAX_DIGIT> &bt) {
auto res = bt.eval();
for (auto val : res) os << val << " ";
return os;
}
};
int main() {
// 入力
long long N, M, K;
cin >> N >> M >> K;
vector<long long> A(N);
for (int i = 0; i < N; ++i) cin >> A[i];
// 小さい順に K 個の総和
long long sum = 0;
// Binary Trie
BinaryTrie<int, 30> left, right;
// push と pop
auto push = [&](long long x) -> void {
// とりあえず x を left に挿入する
left.insert(x);
sum += x;
// left のサイズが K を超えるなら left の最大値を right に移す
if (left.size() > K) {
long long y = left.get(left.get_max());
sum -= y;
left.erase(y);
right.insert(y);
}
};
auto pop = [&](long long x) -> void {
// とりあえず x を削除する
if (x <= left.get(left.get_max())) {
left.erase(x);
sum -= x;
} else {
right.erase(x);
}
// left のサイズが K 未満になるなら right の最小値を left に移す
if (left.size() < K) {
long long y = right.get(right.get_min());
sum += y;
left.insert(y);
right.erase(y);
}
};
for (int i = 0; i < M; ++i) push(A[i]);
for (int i = 0; i < N - M + 1; ++i) {
cout << sum << " ";
if (i+M < N) push(A[i+M]), pop(A[i]);
}
cout << endl;
}
解法 (6):Wavelet Matrix
ユーザ解説 by rsk0315 にて、上位互換問題が Wavelet Matrix で解けることが指摘されている。
実際に Rubikun さんが Wavelet Matrix を活用して、メイン部分をわずか 8 行で実装している!