けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 141 F - Xor Sum 3 (600 点)

Xor Sum シリーズしゃん。典型てんこ盛り!!!

  • XOR は各桁ごとに独立に考えるとよい
  • XOR に関する問題は mod. 2 での方程式みたいになることも多い
  •  2^{k} > 2^{k-1} + \dots + 2^{0} であることから上位の桁から辞書順的に優先で考えるような貪欲が決まる

などなど。

問題へのリンク

問題概要

 N 個の非負整数  A_1, \dots, A_N が与えられる。これらの値を赤と青に塗り分ける方法のうち、「それぞれの色のグループの XOR 和」の和の最大値を求めよ。

制約

  •  2 \le N \le 10^{5}
  •  0 \le A_i \lt 2^{60}

考えたこと

XOR に関する問題は、多くの場合「各桁ごとに独立に考える」という戦法がはまることが多い。今回も各桁ごとに様子を観察することにする。

まず、 N 個のうち  d 桁目が 1 になっているものが奇数個のとき、どのように 2 グループに分けたとしても、「片方の XOR 和は 1」で「他方の XOR 和は 0」という感じになり、最終的に総和に寄与する分は  1 \times 2^{d} になる。つまり最適化の余地は一切なく、常に一定である。

よって  d 桁目が 1 になっているやつが偶数個であるような  d のみに考察を絞って良い。たとえば 6 個だったとき

  • 1 個と 5 個みたいに (奇数個) | (奇数個) に分割したときは、それぞれ XOR 和は 1 になるので、最終的な結果に寄与する分は  2 \times 2^{d}
  • 2 個と 4 個みたいに (偶数個) | (偶数個) に分割したときは、それぞれ XOR 和は 0 になるので、最終的な結果に寄与する分は  0

という感じになって、最終的な結果に寄与する分は  2 \times 2^{d} 0 かのどちらかになる。

理想的には、1 の個数が偶数個であるようなすべての桁  d について、それを (奇数個) | (奇数個) になるように分割できたらよいが、そうはできないこともある。サンプル 1 がまさにそう。どの桁を優先したらよいか...?

上位桁から辞書順に Greedy

こういう状況で、桁  d による寄与分 2^{d} とかに関係ある量であるとき、最上位桁から Greedy にやると良いみたいなことはよくある!!!

理由は、 d 桁目をとったときに  2^{d} の利得が得られるとして、仮に  d-1, d-2, \dots, 0 桁目ですべて利得が 0 であったとしても、 

 2^{d} \gt 2^{d-1} + 2^{d-2} + \dots + 2^{0} 

が成立することから、 d 桁目の利得を獲得しないよりは絶対に獲得した方がよい!!!!!!!!

mod. 2 での連立方程式

さて、1 の個数が奇数個であるような桁  d について、1 のやつが (奇数個) | (奇数個) に分けられるような条件を上手に数式で表現してみよう。そもそも XOR な問題では mod. 2 の連立方程式が立ちそうという印象は結構ある。

さて、 N 個の整数のうち  i 番目の整数を赤色にぬるとき [tex: x{i} = 1 として青色にぬるときは [tex: x{i} = 0] とするような変数  x_i を用意してみる。そうすると、 d 桁目の 1 のあるやつが (奇数個) と (奇数個) に色分けする条件は、 d 桁目が 1 であるやつを抜き取って  y_1, \dots, y_{M} としたとき、

  •  y_1 + y_2 + \dots + y_M ≡ 1 \pmod 2

という風に綺麗に表現することができる!!!!! これを各桁について方程式を立てると、最大 60 本の mod. 2 上での連立方程式となるわけだ。mod. 2 上での連立方程式の解き方は以下の記事に書いた。

drken1215.hatenablog.com

もちろんすべての桁について方程式を立てたときに解があるとは限らない。しかし、掃き出し法を行うときにダメになった桁 (掃き出した後の拡大係数行列の該当する桁の行が (0 0 0 | 1) となる桁) については、無視すれば OK。ここでなるべく上位の桁を優先的に無視されないようにするために、普段は掃き出し法で pivot を上に運ぶように swap しているところを止めればピッタリ。

よって一回の掃き出し法で解くことができて、bitset 高速化なども行うと、 D = 60 とかして、計算量は  O(D^{2}N / 64) になる。

#include <iostream>
#include <vector>
#include <bitset>
using namespace std;

// Bit Matrix
const int MAX_ROW = 110; // to be set appropriately
const int MAX_COL = 110000; // to be set appropriately
struct BitMatrix {
    int H, W;
    bitset<MAX_COL> val[MAX_ROW];
    BitMatrix(int m = 1, int n = 1) : H(m), W(n) {}
    inline bitset<MAX_COL>& operator [] (int i) {return val[i];}
};

// 掃き出し法
void GaussJordan(BitMatrix &A, bool is_extended = false) {
    vector<bool> used(A.H, 0);
    for (int col = 0; col < A.W; ++col) {
        if (is_extended && col == A.W - 1) break;
        int pivot = -1;
        for (int row = 0; row < A.H; ++row) {
            if (used[row]) continue;
            if (A[row][col]) {
                pivot = row;
                break;
            }
        }
        if (pivot == -1) continue;
        for (int row = 0; row < A.H; ++row) {
            if (row != pivot && A[row][col]) A[row] ^= A[pivot];
        }
        used[pivot] = true;
    }
}

const int MD = 60;
int main() {
    int N; cin >> N;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // 方程式を立てる
    long long res = 0;
    BitMatrix B(MD+1, N+1);
    vector<bool> cannot(MD+1, 0);
    for (long long d = MD; d >= 0; --d) {
        // d 桁目が 1 が何個か
        int num = 0;
        for (int i = 0; i < N; ++i) {
            if (A[i] & (1LL<<d)) ++num;
        }
        if (num == 0) {
            cannot[d] = 1;
            continue;
        }
        else if (num & 1) {
            cannot[d] = 1;
            res += (1LL<<d);
            continue;
        }

        for (int i = 0; i < N; ++i) {
            if (A[i] & (1LL<<d)) B[MD-d][i] = 1;
        }
        B[MD-d][N] = 1;
    }
    GaussJordan(B, true);

    // 集計
    for (int d = MD; d >= 0; --d) {
        if (cannot[d]) continue;

        // d 行目が (0 0 0 ... 0 | 1) だったらダメ
        bool ok = false;
        for (int i = 0; i < N; ++i) if (B[MD-d][i]) ok = true;
        if (!B[MD-d][N]) ok = true;

        // 結果に 2 × 2^d が寄与
        if (ok) res += (2LL<<d);
    }
    cout << res << endl;
}