面白かった。重み付き Union-Find を使った。
問題概要
0 と 1 と ? のみからなる長さ の文字列 が与えられる。先頭の文字が 1 であることが保証されている。
以下の条件を満たす整数の組 () の個数を求めよ。
- はともに回文数である (11 や 101 など)
- XOR を計算したとき、 の '?' 以外の部分については値が一致する
制約
考えたこと
のどちらかの桁数は 桁でなければならない。とりあえず の桁数を 桁に固定して考えることにした。
- が 桁未満のときは、求めた個数を全体に合算
- も 桁であるときには、求めた個数を 1/2 にして全体に加算 ( であるものと であるものが半数ずつになるため)
の桁数 を のそれぞれについて考えることにする。このとき、 や の各桁の値は 次元の 0-1 ベクトルとみなせる。この 個の 0-1 変数は
- の最上位の値は 1
- のうち、 桁目より大きいところは に一致 (? 以外)
- と の 桁目の XOR 和は の該当桁に一致 (? 以外)
- と は回文数
という条件を満たすようにする。でもこれは
- or
- XOR or
という 2 タイプの式なので、F2 体上の重み付き Union-Find で管理できる。「根の値の確定しないグループの個数」を として、 通りと求められる。
コード
の計算量となる。
#include <bits/stdc++.h> using namespace std; // modint template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; } constexpr int getmod() const { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept { is >> x.val; x.val %= MOD; if (x.val < 0) x.val += MOD; return is; } friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept { if (n == 0) return 1; if (n < 0) return modpow(modinv(r), -n); auto t = modpow(r, n / 2); t = t * t; if (n & 1) t = t * r; return t; } friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } return Fp<MOD>(u); } }; const int MOD = 998244353; using mint = Fp<MOD>; // Union-Find template<class Abel> struct UnionFind { const Abel UNITY_SUM = 0; // to be set vector<int> par; vector<Abel> diff_weight; vector<Abel> val; vector<int> onum; // 根が 0 のときの 1 の個数 UnionFind() { } UnionFind(int n) : par(n, -1), diff_weight(n, UNITY_SUM) , val(n, -1), onum(n, 0) {} int root(int x) { if (par[x] < 0) return x; else { int r = root(par[x]); diff_weight[x] ^= diff_weight[par[x]]; return par[x] = r; } } Abel calc_weight(int x) { int rx = root(x); return diff_weight[x]; } bool issame(int x, int y) { return root(x) == root(y); } void set(int x, Abel w) { auto rx = root(x); auto dw = diff_weight[x]; val[rx] = w ^ dw; } bool merge(int x, int y, Abel w = 0) { w ^= calc_weight(x); w ^= calc_weight(y); x = root(x), y = root(y); if (x == y) return false; if (par[x] > par[y]) swap(x, y); // merge technique if (w == 0) onum[x] += onum[y]; else onum[x] += -par[y] - onum[y]; if (val[y] != -1) val[x] = val[y] ^ w; par[x] += par[y]; par[y] = x; diff_weight[y] = w; return true; } Abel diff(int x, int y) { return calc_weight(y) ^ calc_weight(x); } int size(int x) { return -par[root(x)]; } int get_onum(int x) { x = root(x); if (val[x] == -1) return min(onum[x], -par[x] - onum[x]); else if (val[x] == 0) return onum[x]; else return -par[x] - onum[x]; } }; mint solve(const string &S) { mint res = 0; int N = (int)S.size(); for (int M = 1; M <= N; ++M) { UnionFind<int> uf(N+M); auto merge = [&](int i, int j, int w) -> bool { if (uf.issame(i, j)) { if (uf.diff(i, j) != w) return false; else return true; } else { w ^= uf.calc_weight(i); w ^= uf.calc_weight(j); i = uf.root(i), j = uf.root(j); if (uf.val[i] != -1 && uf.val[j] != -1 && (uf.val[i] ^ uf.val[j]) != w) return false; uf.merge(i, j, w); return true; } }; bool ok = true; for (int i = 0; i < N-M; ++i) { if (S[i] == '0') uf.set(i, 0); else if (S[i] == '1') uf.set(i, 1); } uf.set(0, 1); uf.set(N, 1); for (int i = 0; i < N; ++i) if (!merge(i, N-i-1, 0)) ok = false; for (int i = 0; i < M; ++i) if (!merge(i+N, M-i-1 + N, 0)) ok = false; for (int i = N-M; i < N; ++i) { if (S[i] == '0') { if (!merge(i, i+M, 0)) ok = false; } else if (S[i] == '1') { if (!merge(i, i+M, 1)) ok = false; } } if (!ok) continue; mint tmp = 1; for (int v = 0; v < N+M; ++v) { if (uf.root(v) != v) continue; if (uf.val[v] == -1) tmp *= 2; } if (M == N) res += tmp / 2; else res += tmp; } return res; } int main() { string S; cin >> S; cout << solve(S) << endl; }