けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 256 G - Black and White Stones (黄色, 600 点)

opt さんの得意系って感じだった!

問題概要

一辺の長さが整数  D の正  N 角形がある。

頂点から始めて、周上に距離 1 ごとに黒い石か白い石を置いていく。

石の置き方のうち、各辺上にある白い石の個数が等しくなるようなものの個数を 998244353 で割ったあまりを求めよ。

制約

  •  3 \le N \le 10^{12}
  •  1 \le D \le 10^{4}

考えたこと

頂点をどうするかによって色々分岐が起こる系の問題。

 N 角形をグルッと回っていくことを考えた時、いかにも「最後の頂点を白石にしたか黒石にしたか」を情報に持ちながら DP をしたくなる。頂点を  0, 1, \dots, N-1 と番号をつけることにする。そして、

  • 頂点 0 の色
  • 一辺の白石の個数  d

をそれぞれ最初に固定することにする。各場合を高速に計算できるようにしたい。

  • dp[i][0] ← 頂点  0, 1, \dots, i まで色を塗ったとき、頂点  i が白色になるような塗り方の個数
  • dp[i][1] ← 頂点  0, 1, \dots, i まで色を塗ったとき、頂点  i が黒色になるような塗り方の個数

とする。このとき、 i から  i+1 への遷移は行列演算として表せる。

dp[i+1][0] =  C(D-1, d-2) \times dp[i][0] +  C(D-1, d-1) \times dp[i][1]
dp[i+1][1] =  C(D-1, d-1) \times dp[i][0] +  C(D-1, d) \times dp[i][1]

よって行列累乗で  O(\log N) の計算量で解ける。全体の計算量は  O(D \log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

// Binomial coefficient
template<class T> struct BiCoef {
    vector<T> fact_, inv_, finv_;
    constexpr BiCoef() {}
    constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) {
        init(n);
    }
    constexpr void init(int n) noexcept {
        fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1);
        int MOD = fact_[0].getmod();
        for(int i = 2; i < n; i++){
            fact_[i] = fact_[i-1] * i;
            inv_[i] = -inv_[MOD%i] * (MOD/i);
            finv_[i] = finv_[i-1] * inv_[i];
        }
    }
    constexpr T com(int n, int k) const noexcept {
        if (n < k || n < 0 || k < 0) return 0;
        return fact_[n] * finv_[k] * finv_[n-k];
    }
    constexpr T fact(int n) const noexcept {
        if (n < 0) return 0;
        return fact_[n];
    }
    constexpr T inv(int n) const noexcept {
        if (n < 0) return 0;
        return inv_[n];
    }
    constexpr T finv(int n) const noexcept {
        if (n < 0) return 0;
        return finv_[n];
    }
};

// matrix
template<class T> struct Matrix {
    vector<vector<T> > val;
    Matrix(int n = 1, int m = 1, T v = 0) : val(n, vector<T>(m, v)) {}
    void init(int n, int m, T v = 0) {val.assign(n, vector<T>(m, v));}
    void resize(int n, int m) {
        val.resize(n);
        for (int i = 0; i < n; ++i) val[i].resize(m);
    }
    Matrix<T>& operator = (const Matrix<T> &A) {
        val = A.val;
        return *this;
    }
    size_t size() const {return val.size();}
    vector<T>& operator [] (int i) {return val[i];}
    const vector<T>& operator [] (int i) const {return val[i];}
    friend ostream& operator << (ostream& s, const Matrix<T>& M) {
        s << endl;
        for (int i = 0; i < (int)M.size(); ++i) s << M[i] << endl;
        return s;
    }
};

template<class T> Matrix<T> operator * (const Matrix<T> &A, const Matrix<T> &B) {
    Matrix<T> R(A.size(), B[0].size());
    for (int i = 0; i < A.size(); ++i)
        for (int j = 0; j < B[0].size(); ++j)
            for (int k = 0; k < B.size(); ++k)
                R[i][j] += A[i][k] * B[k][j];
    return R;
}

template<class T> Matrix<T> pow(const Matrix<T> &A, long long n) {
    Matrix<T> R(A.size(), A.size());
    auto B = A;
    for (int i = 0; i < A.size(); ++i) R[i][i] = 1;
    while (n > 0) {
        if (n & 1) R = R * B;
        B = B * B;
        n >>= 1;
    }
    return R;
}

template<class T> vector<T> operator * (const Matrix<T> &A, const vector<T> &B) {
    vector<T> v(A.size());
    for (int i = 0; i < A.size(); ++i)
        for (int k = 0; k < B.size(); ++k)
            v[i] += A[i][k] * B[k];
    return v;
}

const int MOD = 998244353;
using mint = Fp<MOD>;
BiCoef<mint> bc(21000);

int main() {
    long long N, D;
    cin >> N >> D;
    
    mint res = 0;
    for (int d = 0; d <= D+1; ++d) {
        // 行列
        Matrix<mint> M(2, 2);
        M[0][0] = bc.com(D-1, d-2);
        M[0][1] = bc.com(D-1, d-1);
        M[1][0] = bc.com(D-1, d-1);
        M[1][1] = bc.com(D-1, d);
        
        // 頂点 0 が白色の場合
        vector<mint> v({1, 0});
        v = pow(M, N-1) * v;
        res += v[0] * bc.com(D-1, d-2) + v[1] * bc.com(D-1, d-1);
        
        // 頂点 0 が黒色の場合
        v = {0, 1};
        v = pow(M, N-1) * v;
        res += v[0] * bc.com(D-1, d-1) + v[1] * bc.com(D-1, d);
    }
    cout << res << endl;
}