けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

CADDi 2018 F - Square (900 点)

面白かった

問題へのリンク

問題概要

 N \times N の盤面の各マスを 0 か 1 かで埋めたい。すでに  M 個のマスについては数字が埋まっている。以下の条件を満たすように残り  N^{2} - M マスを埋める方法は何通りあるか、998244353 で割ったあまりで求めよ。

  • 一辺の長さが 2 以上な部分正方形領域であって、その対角線のうちの左上から右下へのラインが、全体の正方形の対角線 (左上から右下) 上にあるようなものを考える
  • そのような正方形領域すべてについて、それに含まれる 1 の個数が偶数個である (値の和が偶数)

制約

  •  2 \le N \le 10^{5}
  •  M \le 5 \times 10^{4}

まずは判定問題を

この手の「操作によって何通りできるか」を問う問題は、まずは操作によって〜という盤面が作れるかどうかの判定問題を解くと良いと学んだ。それを解く。

まず感じたことは、ほとんどの部分において、対角線を挟んで対称なマス同士には 1 本の制約式がかかるのだ。たとえば、こんな感じ。

f:id:drken1215:20200206100205p:plain

こんな風に、条件で指定される 4 × 4 の正方形に対しては、その右上と左下の総和が偶数、すなわち、右上と左下の値が等しいことがいえる。同じことが 5 × 5 とかにもいえる。一般に以下のことがいえる

  • | i - j | >= 3 のとき、(i, j) と (j, i) は値が等しい

というわけで、もしすでに埋まっている  M マスの中にこの条件を満たさないものがあったらだめ。満たしていても (i, j) と (j, i) のうち少なくとも一方に値があったら一意に決まり、そうでなかったら 2 通りずつ生えることになる。この値は最後に掛け算すれば OK。

残りのマス

3 × 3 については、4 × 4 とは異なる状況が生まれる。それは上図でいうところの「重なる領域」が 1 × 1 になるからだ。そこは偶数とは限らない。でも次のことが言える

  • 3 × 3 の正方形の中央マスが 0 だったら、右上と左下のパリティは等しい
  • 3 × 3 の正方形の中央マスが 1 だったら、右上と左下のパリティは異なる

というわけでまとめると、

  • (i, i), (i-1, j+1), (i+1, j-1) の総和は偶数

という制約とみなすことができる。以上をまとめると、問題の条件は以下の制約として書き下すことができる。逆にこれらを満たせばすべての対角部分正方形についても条件を満たすことがいえる。


  • |i - j| >= 3 なる i, j に対して、(i, j) と (j, i) のパリティは等しい
  • (i, i), (i-1, j+1), (i+1, j-1) の総和は偶数 (i = 1, 2, ..., N-2 に対して)
  • (i, i), (i+1, i), (i, i+1), (i+1, i+1) の総和は偶数 (i = 0, 1, ..., N-2 に対して)

この条件は、どの制約も互いに一次独立になっていることに注意する。

数え上げへ

ここまで特徴付けができれば数え上げできそう。方針を 2 つ考えてみる。

  1. 自由度を数える (制約のランクを数える)
  2. bitDP で遷移していく

1 について、この問題の制約はすべて F2 上の線形方程式なので、その解の個数はかならず 2 のべき乗になる。もちろん Gauss-Jordan の掃き出し法とかやるわけにはいかないが、一次独立な制約式の本数を数えることはできる。それがわかれば  2^{(変数の個数) - (一次独立な制約式の本数)} で答えが求められる。

2 については、この問題は結局は対角線に近い領域を除けば自明な感じになることから、対角線に近いところのみを bitDP で遷移しながらやっていく方針である。

1. 制約のランクを数える

もし  M 個のマス指定がなかったらもうすでに解くことができていて、制約の本数は上にあげた通り、 \frac{N(N-1)}{2} 本なので、変数の個数は  N^{2} 個だから答えは、 \frac{N(N+1)}{2} 通りとなる (サンプル 5 で確かめ可能)。

あとは  M 個のマスの値指定は、それ自体が 1 つずつ新たな制約として加わっていくイメージである。しかしその中には

  • 矛盾するもの
  • 冗長なもの

が出てくる可能性がある。矛盾が発生したらその時点で 0 通り。冗長性のチェックについては、上に掲げた  \frac{N(N-1)}{2} 本の制約それぞれについて「すでに満たしているもの」については「制約式のランク」から引いていけばよい。

注意点として、下のようなパターンだと複合的に矛盾が生じる可能性があるので、それも検出してあげる必要がある。

f:id:drken1215:20200206232618p:plain

こういうのは「制約式から決まるマス」は適宜埋めていきながら扱っていくと良さそう。それで通った。

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
const int MOD = 998244353;
using mint = Fp<MOD>;
using pint = pair<int,int>;

long long N, M;
map<pint,int> fi;

mint solve() {
    long long honsuu = N*(N-1)/2 + M;

    // 差 3 以上の対角制約がすでに満たしているか
    set<pint> alr;
    for (auto it : fi) {
        pint p = it.first;
        if (alr.count(p)) continue;
        if (abs(p.first - p.second) >= 3) {
            int v = it.second;
            pint rp = {p.second, p.first};
            if (fi.count(rp)) {
                if (fi[rp] != v) return 0;
                else --honsuu;
            }
            alr.insert(p); alr.insert(rp);
        }
    }
    
    // 斜め 3 制約をあらかじめ埋める
    for (int i = 1; i < N-1; ++i) {
        int sum = 0;
        if (fi.count(pint(i-1,i+1)) && fi.count(pint(i+1,i-1))) {
            int v = fi[pint(i-1,i+1)] ^ fi[pint(i+1,i-1)];
            if (fi.count(pint(i,i))) {
                if (fi[pint(i,i)] != v) return 0;
                else --honsuu;
            }
            else fi[pint(i,i)] = v;
        }
    }

    // 2 × 2 制約
    for (int i = 0; i < N-1; ++i) {
        int sum = 0;
        bool all = true;
        if (fi.count(pint(i,i+1)) && fi.count(pint(i+1,i))) {
            int v = fi[pint(i,i+1)] ^ fi[pint(i+1,i)];
            if (fi.count(pint(i,i)) && fi.count(pint(i+1,i+1))) {
                if ((fi[pint(i,i)] ^ fi[pint(i+1,i+1)]) != v) return 0;
                else --honsuu;
            }
            else if (fi.count(pint(i,i))) fi[pint(i+1,i+1)] = v ^ fi[pint(i,i)];
            else if (fi.count(pint(i+1,i+1))) fi[pint(i,i)] = v ^ fi[pint(i+1,i+1)];
        }
    }
    
    return modpow(mint(2), N*N - honsuu);
}

int main() {
    cin >> N >> M;
    fi.clear();
    for (int i = 0; i < M; ++i) {
        int a, b, c; cin >> a >> b >> c; --a, --b;
        fi[pint(a, b)] = c;
    }
    cout << solve() << endl;
}

2. bitDP で遷移していく

本質となる部分は 2 × 2 の正方形が対角線上に並んでいる部分なので、

  • dp[ i ][ bit ] := (i, i) を左上とする 2 × 2 の正方形については bit な埋め方をするような場合の数 (bit は 24 = 16 通り)

として、遷移していく。それ以外の部分については自由度を前処理で求めておく。

#include <bits/stdc++.h>
using namespace std;

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
const int MOD = 998244353;
using mint = Fp<MOD>;
using pint = pair<int,int>;

long long N, M;
map<pint,int> fi;

mint solve() {
    // 斜めの自由度を先に計算
    long long jiyudo = (N - 1) * (N - 2) / 2;
    set<pint> alr;
    for (auto it : fi) {
        auto p = it.first;
        if (alr.count(p)) continue;
        if (abs(p.first - p.second) >= 3) {
            if (fi.count({p.second, p.first})) {
                if (fi[p] != fi[{p.second, p.first}]) return 0;
            }
            --jiyudo;
        }
        else if (abs(p.first - p.second) == 2) {
            pint q = {p.second, p.first};
            pint c = {(p.first+p.second)/2, (p.first+p.second)/2};
            if (fi.count(p) && fi.count(q)) {
                if (fi.count(c)) {
                    if ((fi[p]+fi[q]+fi[c]) % 2 != 0) return 0;
                }
                fi[c] = fi[p]^fi[q];
            }
            --jiyudo;
        }
        alr.insert(p); alr.insert({p.second, p.first});
    }

    set<int> can;
    for (int bit = 0; bit < (1<<4); ++bit) {
        if (__builtin_popcount(bit) % 2 != 0) continue;
        can.insert(bit);
    }
    vector<vector<mint>> dp(N+1, vector<mint>(1<<4, 0));
    for (auto bit : can) {
        bool ok = true;
        for (int i = 0; i < 4; ++i) {
            pint p(i/2, i%2);
            if (fi.count(p)) if ( ((bit>>i)&1) ^ fi[p] ) ok = false;
        }
        if (ok) dp[0][bit] = 1;
    }
    for (int t = 0; t < N-2; ++t) {
        for (auto bit : can) {
            if (dp[t][bit] == 0) continue;
            for (auto bit2 : can) {
                if ( ((bit>>3)&1) ^ ((bit2>>0)&1) ) continue;
                bool ok = true;
                for (int i = 0; i < 4; ++i) {
                    pint p(t+1+i/2, t+1+i%2);
                    if (fi.count(p)) if (((bit2>>i)&1) ^ fi[p]) ok = false;
                }
                if (ok) dp[t+1][bit2] += dp[t][bit];
            }
        }
    }
    mint res = 0;
    for (int bit = 0; bit < (1<<4); ++bit) res += dp[N-2][bit];
    return res * modpow(mint(2), jiyudo);
}

int main() {
    cin >> N >> M;
    fi.clear();
    for (int i = 0; i < M; ++i) {
        int a, b, c; cin >> a >> b >> c; --a, --b;
        fi[pint(a, b)] = c;
    }
    cout << solve() << endl;
}