けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 3213 Xor Mart (OUPC 2020 E)

半分全列挙 + Binary Trie!!

あるいは、Binary Trie 自体を意識しなくても、上の位から順に桁 DP 的発想で考えていると、それが自然に Binary Trie 上の探索そのものとみなせる!

すごくいい Binary Trie の経験になった!!

問題概要

2 組の  N 個の整数  A_{1}, \dots, A_{N} B_{1}, \dots, B_{N} が与えられる ( A_{i} は 0 以上)。

 1, 2, \dots, N の部分集合  i_{1}, \dots, i_{K} であって、 A_{i_{1}}, \dots, A_{i_{K}} の XOR 和が  M 以下であるもののうち、 B_{i_{1}} + \dots + B_{i_{K}} の最大値を求めよ。

制約

  •  1 \le N \le 34
  •  0 \le M \le 10^{9}

考えたこと

見た目はいかにも半分全列挙!!!
半分全列挙することを考えると、次のような処理を高速にこなせるデータ構造を設計したくなる。

  1. キー  x、値  v の要素を挿入する
  2. 非負整数  A, M が与えられたとき、 x XOR  A \le M を満たすようなキー値  x に対する値の最大値を求める

まず 1 のようなタスクがある時点で、Binary Trie で管理したくなる。Trie で管理しておくと、上位桁から 0 または 1 へと辿っていく処理も明快に書ける。

2 については、通常の Binary Trie でサポートされている処理そのものではないので、Binary Trie 上の探索を考えていくことにする。なお、たとえば

101001****

のように、「**** の部分はなんでもよい」というようなノード v も、Trie では管理できることに注意する。そして、そのようなキー値たちに対する値  v の最大値も、v に持たせることができる。Binary Trie に要素を insert するときに、root からたどる各ノードの最大値を chmax 更新していけば OK。

2 のタスクについて

 x XOR  A \le M を満たすような整数  x をすべて探索することを考えてみよう。上位桁から考えて行って、現在  d 桁目のノード v を考えているとしよう。次のようにすることで、処理することができる。

  •  d = -1 となったら、ノード v の値を返せば OK
  •  A d 桁目が 0 のとき
    •  M d 桁目が 0 ならば、v->left について再帰的に処理する
    •  M d 桁目が 1 ならば、以下の大きい方を採用
      •  x d 桁目が 0 の場合を考えると、残りは好きにしてよいので、v->left の値 (累積 max) を参照する
      •  x d 桁目が 1 の場合を考えると、v->right について再帰的に処理する
  •  A d 桁目が 1 のとき
    •  M d 桁目が 0 ならば、v->right について再帰的に処理する
    •  M d 桁目が 1 ならば、以下の大きい方を採用
      •  x d 桁目が 1 の場合を考えると、残りは好きにしてよいので、v->right の値 (累積 max) を参照する
      •  x d 桁目が 0 の場合を考えると、v->left について再帰的に処理する

これらの計算量は、最大桁数を  D として、 O(D) となる。

半分全列挙全体の計算量

次のようになる。

  •  O(2^{\frac{N}{2}}) 個の要素を Trie に挿入: O(DN2^{\frac{N}{2}})
  •  O(2^{\frac{N}{2}}) 通りの Binary Trie 上のクエリ応答処理: O(DN2^{\frac{N}{2}})

コード

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }
const long long INF = 1LL<<60;

// Binary Trie
template<typename INT, size_t MAX_DIGIT> struct BinaryTrie {
    struct Node {
        size_t count;
        Node *prev, *left, *right;
        long long happy;
        Node(Node *prev, long long h) : count(0), prev(prev), left(nullptr), right(nullptr), happy(h) {}
    };
    Node *root;

    // constructor
    BinaryTrie() : root(emplace(nullptr, -INF)) {}
    inline size_t get_count(Node *v) const { return v ? v->count : 0; }

    // insert
    inline Node* emplace(Node *prev, long long happy) {
        return new Node(prev, happy);
    }
    void insert(INT val, long long happy, size_t k = 1) {
        INT nval = val;
        Node *v = root;
        for (int i = MAX_DIGIT-1; i >= 0; --i) {
            chmax(v->happy, happy);
            bool flag = (nval >> i) & 1;
            if (flag && !v->right) v->right = emplace(v, happy);
            if (!flag && !v->left) v->left = emplace(v, happy);
            if (flag) v = v->right;
            else v = v->left;
        }
        v->count += k;
        chmax(v->happy, happy);
        while ((v = v->prev)) v->count = get_count(v->left) + get_count(v->right);
    }

    // query (val^x が M 以下となる val すべてについての happy の最大値)
    long long query(INT val, INT M, int depth, Node *v) {
        if (!v) return -INF;
        if (depth == -1) return v->happy;
        long long res = -INF;
        bool fval = (val >> depth) & 1;
        bool fM = (M >> depth) & 1;
        if (!fval) {
            if (!fM) chmax(res, query(val, M, depth-1, v->left));
            else {
                if (v->left) chmax(res, v->left->happy);
                chmax(res, query(val, M, depth-1, v->right));
            }
        }
        else {
            if (!fM) chmax(res, query(val, M, depth-1, v->right));
            else {
                if (v->right) chmax(res, v->right->happy);
                chmax(res, query(val, M, depth-1, v->left));
            }
        }  
        return res;
    }
    long long query(INT val, INT M) {
        return query(val, M, MAX_DIGIT-1, root);
    }

    // debug
    void print(Node *v, string prefix = "") {
        if (!v) return;
        cout << prefix << ": " << v->happy << endl;
        print(v->left, prefix + "0");
        print(v->right, prefix + "1");
    }
    void print() {
        print(root);
    }
    vector<pair<INT,long long>> eval(Node *v, int digit) const {
        vector<pair<INT,long long>> res;
        if (!v) return res;
        if (!v->left && !v->right) {
            for (int i = 0; i < get_count(v); ++i) res.emplace_back(0, v->happy);
            return res;
        }
        const auto& left = eval(v->left, digit-1);
        const auto& right = eval(v->right, digit-1);
        for (auto val : left) res.push_back(val);
        for (auto val : right) res.emplace_back(val.first + (INT(1)<<digit), val.second);
        return res;
    }
    vector<pair<INT,long long>> eval() const {
        auto res = eval(root, MAX_DIGIT-1);
        return res;
    }
    friend ostream& operator << (ostream &os,
                                 const BinaryTrie<INT, MAX_DIGIT> &bt) {
        auto res = bt.eval();
        for (auto val : res) os << val << " ";
        return os;
    }
};


int main() {
    long long N, M;
    cin >> N >> M;
    vector<long long> A(N), B(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    for (int i = 0; i < N; ++i) cin >> B[i];

    BinaryTrie<long long, 35> bt;
    int num = N - N/2;
    for (int bit = 0; bit < (1<<num); ++bit) {
        long long asum = 0, bsum = 0;
        for (int i = 0; i < num; ++i) {
            if (bit & (1<<i)) asum ^= A[i+N/2], bsum += B[i+N/2];
        }
        bt.insert(asum, bsum);
    }

    long long res = -INF;
    for (int bit = 0; bit < (1<<(N/2)); ++bit) {
        long long asum = 0, bsum = 0;
        for (int i = 0; i < N/2; ++i) {
            if (bit & (1<<i)) asum ^= A[i], bsum += B[i];
        }
        long long tmp = bt.query(asum, M) + bsum;
        chmax(res, tmp);
    }
    cout << res << endl;
}