けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 212 H - Nim Counting (橙色, 600 点)

コンテスト中にアダマール変換を思い出せたのはよかった!

問題概要

整数  N K 個の整数  A_{1}, \dots, A_{K} が与えられます。次の条件を満たすような Nim をすべて考えます。

  • 山の個数は  1, 2, \dots, N のいずれかである
  • 各山の石の個数は  A_{1}, \dots, A_{K} のいずれかである

このような Nim の盤面は  K + K^{2} + \dots + K^{N} 通り考えられますが、そのうち先手必勝であるものが何通りあるか 998244353 で割ったあまりを求めてください。

制約

  •  1 \le N \le 2 \times 10^{5}
  •  1 \le K \lt 2^{16}
  •  1 \le A_{i} \lt 2^{16}

考えたこと

石の個数を  (x_{1}, \dots, x_{N} としたとき、先手必勝である条件は

 x_{1} ^  x_{2} ^ ... ^  x_{N} \neq 0

となることは有名事実である (Nim)。 \neq 0 になるものは数えにくいので、代わりに  = 0 となるものの個数を求めて全体から引くことにしよう。

さて、 A_{i} \lt 2^{16} の制約から、ひとまず次のような DP を考えたくなる。


dp[n][v] n 個の山を考えたときに、それらの石の山の個数の XOR 和が  v となるような場合の数


ここで  A_{i} の制約から、 v の値としてありうるのは  0, 1, \dots, 2^{16}-1 のいずれかに限られることがわかる (以降、 V = 2^{16} とする)。

さて、DP の遷移を考えると、次のようになる。

dp[n+1][v ^ A[i]] += dp[n][v] (各  i に対して)

しかしこのままでは、 O(NVK) の計算量になる。 N の部分はダブリングによって  \log N にできる (これは比較的典型) としても、それでも全然間に合わない。

添字 XOR 畳み込み

dp が間に合わないとなったとき添字の演算に着目してなんらかの畳み込みとみなすのはよくある気がする。最も多いのは FFT に帰着できるケース。

今回は遷移式の添字に「^」が出ているから、きっと添字 XOR 畳み込みに帰着できそうな気がする。具体的には、 A_{1}, \dots, A_{K} の特性関数を考えてみよう。つまり、

 f(i) = 1 ( i = A_{j} となる  j があるとき)、0 (ないとき)

という整数関数を考えるといい感じになる。 f V 次元ベクトルとみなせる。このとき、関数  g, h に対して、演算子  \cdot を次のように定義しよう。

 (g \cdot h)(n) = \sum_{i {\rm xor} j} g(i)h(j)

このとき、dp[n][i] (f \cdot f \cdot \dots \cdot f)(i) ( f n 回合成したものを考えている) に一致することが言える。よって今回は、上記の関数  f に対して


 f + f^{2} + \dots + f^{N}


を求める問題に帰着された (ここで  f^{i} f i 回合成したものを表すものとする)。

以上より今回の問題は、添字 XOR 畳み込みに帰着されたのであった。

高速アダマール変換

添字 XOR 畳み込みは、愚直に計算すると  O(V^{2}) の計算量がかかってしまう。しかしうまくやると  O(V \log V) の計算量で実現できる。具体的には、

 h = f \cdot g

となっているときに、関数  f, g, h をアダマール変換した結果を  F, G, H とすると、 n = 0, 1, \dots, V-1 に対して

 H(n) = F(n)G(n)

が成立するのだ。つまり、 F, G から  H の計算は  O(V) でできる。まとめると、 f, g から  f \cdot g を計算するのは

  •  f, g をアダマール変換して  F, G とする ( O(V \log V) でできる)
  •  H を計算する ( O(V)
  •  H をアダマール逆変換して  h を求める ( O(V \log V) でできる)

という手順でできることになる。今回の問題でも同様のことができる。

 f + f^{2} + \dots + f^{N} (ここでの冪乗は XOR 畳み込み演算)

をアダマール変換すると

 F + F^{2} + \dots + F^{N} (ここでの冪乗は各点の積)

になるのだ。これを計算して、最後にアダマール逆変換すれば OK。なお、 f + g をアダマール変換すると  F + G になること (線形性) は、アダマール変換の仕組みを考えれば納得できる。

級数和を求める

最後に  F + F^{2} + \dots + F^{N} = F(1 + F + \dots + F^{N-1} を求める方法を考える。等比級数の和なので、等比級数の和の公式を使いたくなる (今回はそれでも OK) のだが、一般には  F に逆元が存在するとは限らないので別の方法を考えてみよう。

たとえば、小さい  N で考えたとき

  •  1 + F + F^{2} + F^{3} + F^{4} + F^{5} = (1 + F)(1 + F^{2} + F^{4})
  •  1 + F + F^{2} + F^{3} + F^{4} + F^{5} + F^{6} = 1 + F(1 + F)(1 + F^{2} + F^{4})

のように変形できることに注目しよう。これによって  (F, N) に関する問題が  (F^{2}, N/2) に関する問題へと帰着されるので、演算回数を  O(\log N) に落とせることがわかった。

全体として計算量は  O(V \log V \log N) になる。これなら間に合う。

コード

#include <bits/stdc++.h>
using namespace std;

// Fast Hadamard Transform
// N must be 2^K for some K
namespace FastHadamardTransform {
    template<class T> void trans(vector<T> &v, bool inv = false) {
        int N = v.size();
        for (int i = 1; i < N; i <<= 1) {
            for (int j = 0; j < N; ++j) {
                if ((j & i) == 0) {
                    auto x = v[j], y = v[j | i];
                    v[j] = x + y;
                    v[j | i] = x - y;
                    if (inv) v[j] /= 2, v[j | i] /= 2;
                }
            }
        }
    }

    template<class T> vector<T> mul(const vector<T> &a, const vector<T> &b) {
        int N = a.size();
        auto A = a, B = b;
        trans(A), trans(B);
        vector<T> C(N);
        for (int i = 0; i < N; ++i) C[i] = A[i] * B[i];
        trans(C, true);
        return C;
    }
};
using namespace FastHadamardTransform;


// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;


// 級数和を求める
vector<mint> operator + (const vector<mint> &f, const vector<mint> &g) {
    vector<mint> res(f.size());
    for (int i = 0; i < f.size(); ++i) {
        res[i] = f[i] + g[i];
    }
    return res;
} 
vector<mint> operator * (const vector<mint> &f, const vector<mint> &g) {
    vector<mint> res(f.size());
    for (int i = 0; i < f.size(); ++i) {
        res[i] = f[i] * g[i];
    }
    return res;
}

// 級数和
// unit: 冪乗演算の単位元
template<class T> T calc_series(T v, long long N, T unit) {
    if (N == 1) return unit;

    if (N % 2 == 1)
        return v * calc_series(v, N - 1, unit) + unit;
    else
        return (v + unit) * calc_series(v * v, N / 2, unit);
}    

// ベクトルサイズ
const int V = 65536;

int main() {
    // 入力
    int N, K;
    cin >> N >> K;
    vector<mint> f(V);
    for (int i = 0; i < K; ++i) {
        int a;
        cin >> a;
        f[a] = 1;
    }

    // 全体の答え
    mint all = mint(K) * calc_series(mint(K), N, mint(1));

    // 級数和を求める
    vector<mint> unit(V, 0);
    unit[0] = 1;
    trans(f), trans(unit);  // 高速アダマール変換
    auto res = f * calc_series(f, N, unit);
    trans(res, true);

    cout << all - res[0] << endl;
}