けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

みんなのプロコン 2019 E - Odd Subrectangles (800 点)

すごく面白そうだし、これ考えたかった

問題へのリンク

問題概要

 N ×  M の binary 行列  A が与えられる。

  • 行集合の部分集合  2^{N} 通り
  • 列集合の部分集合  2^{M} 通り

の組であって、 A の各要素のうち、行と列がともに該当する部分集合に含まれるようなものの総和が奇数となっているものを数え上げよ。

制約

  •  1 \le N, M \le 300

考えたこと

問題を読む前に先に TL で掃き出し法という言葉を見てしまったので。。。でも自然な考察をしていったらそうなりそうかな。。。

気持ちとして、1 が固まっていて欲しい気持ちになるので、ひとまず

  • 行や列を入れ替えてもいい

というのは真っ先に思うところである。そうすると掃き出し法っぽいことが成立しないかを疑いたくなる。行 a に対して行 b を加えて

(a, b) -> (a ^ b, b)

としたとき、前者の a と b の部分だけを考えたときの、行の部分集合の組は

(φ, a, b, a ^ b)

になるのに対し、後者も実は

(φ, a ^ b, b, a)

であって全く一緒になる。つまり、掃き出し法の操作を行っても、答えは変わらない。これは行方向だけでなく、列方向についても言える。よって最終結果は

E O
O O

みたいな形になる。この単位行列部分のサイズ (すなわち元の行列  A のランク) を  r として、奇数になるのを数え上げればいい。

最後の詰め

とりあえずサイズ  r単位行列についてだけ考えればいい。それに最後の

  •  2^{N-r} × 2^{M-r}

をかければいい。行を  i 個選ぶとき ( i = 1, 2, \dots, r)

  •  i 個の選び方は、 C(r, i) 通り
  • 列の方は、その  i 個以外についてはどう選んでもよくて  2^{r-i} 通りを最後にかける
  •  i 個から奇数個選ぶ場合の数は  2^{i-1} 通り (これは有名事実 --- 偶数個選ぶ場合の数と奇数個選ぶ場合の数は等しい)

これらを全部掛け合わせればいい。このままでも通せる。そして実はさらに計算するともっと簡単な式になる。

 2^{N+M-2r} \sum_{i = 1}^{r} C(r, i) 2^{r-i} 2^{i-1}
 = 2^{N+M-r-1} \sum_{i = 1}^{r} C(r, i)
 = 2^{N+M-r-1} (2^{r} - 1)
 = 2^{N+M-1} - 2^{N+M-r-1}

と求められる。

#include <iostream>
#include <vector>
#include <bitset>
using namespace std;


const int MAX_ROW = 510; // to be set appropriately
const int MAX_COL = 510; // to be set appropriately
struct BitMatrix {
    int H, W;
    bitset<MAX_COL> val[MAX_ROW];
    BitMatrix(int m = 1, int n = 1) : H(m), W(n) {}
    inline bitset<MAX_COL>& operator [] (int i) {return val[i];}
};

int GaussJordan(BitMatrix &A, bool is_extended = false) {
    int rank = 0;
    for (int col = 0; col < A.W; ++col) {
        if (is_extended && col == A.W - 1) break;
        int pivot = -1;
        for (int row = rank; row < A.H; ++row) {
            if (A[row][col]) {
                pivot = row;
                break;
            }
        }
        if (pivot == -1) continue;
        swap(A[pivot], A[rank]);
        for (int row = 0; row < A.H; ++row) {
            if (row != rank && A[row][col]) A[row] ^= A[rank];
        }
        ++rank;
    }
    return rank;
}


const int MOD = 998244353;
long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

int main() { 
    int N, M; cin >> N >> M;
    BitMatrix A(N, M);
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < M; ++j) {
            int a; cin >> a;
            if (a) A[i][j] = 1;
        }
    }
    vector<int> res;
    int r = GaussJordan(A);  
    cout << (modpow(2LL, N+M-1, MOD) - modpow(2LL, N+M-r-1, MOD) + MOD) % MOD << endl;
}