半分全列挙 + Binary Trie!!
あるいは、Binary Trie 自体を意識しなくても、上の位から順に桁 DP 的発想で考えていると、それが自然に Binary Trie 上の探索そのものとみなせる!
すごくいい Binary Trie の経験になった!!
問題概要
2 組の 個の整数 、 が与えられる ( は 0 以上)。
の部分集合 であって、 の XOR 和が 以下であるもののうち、 の最大値を求めよ。
制約
考えたこと
見た目はいかにも半分全列挙!!!
半分全列挙することを考えると、次のような処理を高速にこなせるデータ構造を設計したくなる。
- キー 、値 の要素を挿入する
- 非負整数 が与えられたとき、 XOR を満たすようなキー値 に対する値の最大値を求める
まず 1 のようなタスクがある時点で、Binary Trie で管理したくなる。Trie で管理しておくと、上位桁から 0 または 1 へと辿っていく処理も明快に書ける。
2 については、通常の Binary Trie でサポートされている処理そのものではないので、Binary Trie 上の探索を考えていくことにする。なお、たとえば
101001****
のように、「**** の部分はなんでもよい」というようなノード v も、Trie では管理できることに注意する。そして、そのようなキー値たちに対する値 の最大値も、v に持たせることができる。Binary Trie に要素を insert するときに、root からたどる各ノードの最大値を chmax 更新していけば OK。
2 のタスクについて
XOR を満たすような整数 をすべて探索することを考えてみよう。上位桁から考えて行って、現在 桁目のノード v を考えているとしよう。次のようにすることで、処理することができる。
- となったら、ノード v の値を返せば OK
- の 桁目が 0 のとき
- の 桁目が 0 ならば、v->left について再帰的に処理する
- の 桁目が 1 ならば、以下の大きい方を採用
- の 桁目が 0 の場合を考えると、残りは好きにしてよいので、v->left の値 (累積 max) を参照する
- の 桁目が 1 の場合を考えると、v->right について再帰的に処理する
- の 桁目が 1 のとき
- の 桁目が 0 ならば、v->right について再帰的に処理する
- の 桁目が 1 ならば、以下の大きい方を採用
- の 桁目が 1 の場合を考えると、残りは好きにしてよいので、v->right の値 (累積 max) を参照する
- の 桁目が 0 の場合を考えると、v->left について再帰的に処理する
これらの計算量は、最大桁数を として、 となる。
半分全列挙全体の計算量
次のようになる。
- 個の要素を Trie に挿入:
- 通りの Binary Trie 上のクエリ応答処理:
コード
#include <bits/stdc++.h> using namespace std; template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; } template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; } const long long INF = 1LL<<60; // Binary Trie template<typename INT, size_t MAX_DIGIT> struct BinaryTrie { struct Node { size_t count; Node *prev, *left, *right; long long happy; Node(Node *prev, long long h) : count(0), prev(prev), left(nullptr), right(nullptr), happy(h) {} }; Node *root; // constructor BinaryTrie() : root(emplace(nullptr, -INF)) {} inline size_t get_count(Node *v) const { return v ? v->count : 0; } // insert inline Node* emplace(Node *prev, long long happy) { return new Node(prev, happy); } void insert(INT val, long long happy, size_t k = 1) { INT nval = val; Node *v = root; for (int i = MAX_DIGIT-1; i >= 0; --i) { chmax(v->happy, happy); bool flag = (nval >> i) & 1; if (flag && !v->right) v->right = emplace(v, happy); if (!flag && !v->left) v->left = emplace(v, happy); if (flag) v = v->right; else v = v->left; } v->count += k; chmax(v->happy, happy); while ((v = v->prev)) v->count = get_count(v->left) + get_count(v->right); } // query (val^x が M 以下となる val すべてについての happy の最大値) long long query(INT val, INT M, int depth, Node *v) { if (!v) return -INF; if (depth == -1) return v->happy; long long res = -INF; bool fval = (val >> depth) & 1; bool fM = (M >> depth) & 1; if (!fval) { if (!fM) chmax(res, query(val, M, depth-1, v->left)); else { if (v->left) chmax(res, v->left->happy); chmax(res, query(val, M, depth-1, v->right)); } } else { if (!fM) chmax(res, query(val, M, depth-1, v->right)); else { if (v->right) chmax(res, v->right->happy); chmax(res, query(val, M, depth-1, v->left)); } } return res; } long long query(INT val, INT M) { return query(val, M, MAX_DIGIT-1, root); } // debug void print(Node *v, string prefix = "") { if (!v) return; cout << prefix << ": " << v->happy << endl; print(v->left, prefix + "0"); print(v->right, prefix + "1"); } void print() { print(root); } vector<pair<INT,long long>> eval(Node *v, int digit) const { vector<pair<INT,long long>> res; if (!v) return res; if (!v->left && !v->right) { for (int i = 0; i < get_count(v); ++i) res.emplace_back(0, v->happy); return res; } const auto& left = eval(v->left, digit-1); const auto& right = eval(v->right, digit-1); for (auto val : left) res.push_back(val); for (auto val : right) res.emplace_back(val.first + (INT(1)<<digit), val.second); return res; } vector<pair<INT,long long>> eval() const { auto res = eval(root, MAX_DIGIT-1); return res; } friend ostream& operator << (ostream &os, const BinaryTrie<INT, MAX_DIGIT> &bt) { auto res = bt.eval(); for (auto val : res) os << val << " "; return os; } }; int main() { long long N, M; cin >> N >> M; vector<long long> A(N), B(N); for (int i = 0; i < N; ++i) cin >> A[i]; for (int i = 0; i < N; ++i) cin >> B[i]; BinaryTrie<long long, 35> bt; int num = N - N/2; for (int bit = 0; bit < (1<<num); ++bit) { long long asum = 0, bsum = 0; for (int i = 0; i < num; ++i) { if (bit & (1<<i)) asum ^= A[i+N/2], bsum += B[i+N/2]; } bt.insert(asum, bsum); } long long res = -INF; for (int bit = 0; bit < (1<<(N/2)); ++bit) { long long asum = 0, bsum = 0; for (int i = 0; i < N/2; ++i) { if (bit & (1<<i)) asum ^= A[i], bsum += B[i]; } long long tmp = bt.query(asum, M) + bsum; chmax(res, tmp); } cout << res << endl; }