けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Tenka1 2019 F - Banned X (赤色, 800 点)

かなり時間かかった

問題へのリンク

問題概要

 0, 1, 2 のみからなる長さ  N の数列であって、どの連続する部分列に対してもその総和が  X にならないようなものの個数を 998244353 で割ったあまりを求めよ。

制約

  •  1 \le N \le 3000
  •  1 \le K \le 2N

考えたこと

最初は包除原理かな...と思ったけど、どうにもうまくできそうにない。そこで  0, 1, 2 のみであることを活かした解法になりそうだという気持ちになる。まず、 0 は取り除いてしまっても良くて、そうすると

  •  1, 2 のみからなる長さ  N 以下の数列であって、どの連続する部分列に対してもその総和が  K にならないようなものの個数 (ただし長さ  m であれば最後に  C(N, m) をかける)

を数え上げる問題ということになる。イメージ的には数直線上で、 0 から出発して、 1 または  2 ずつ進んで行くが、途中で  i を踏んだら  i+K を踏んではいけないという問題になる。

少し色々手を動かしてみると、結構簡単に  K-ジャンプは実現してしまいそうな気持ちになる。例えば

  •  K = 8 として
  •  4, 5 を踏んだら、 12, 13 は踏めない。
  • そうすると  11 以下から  14 以上へと行く手段がない。

という感じ。つまり「1 進む」を一回でもやると自由度が一気に減るのだ。ここで「初めて 1 進むのをやる場所」で場合分けしようという気持ちになる。すなわち

  •  0, 2, 4, \dots, 2i, 2i+1, \dots

という移動をする場合を考えてみる。まず  K が偶数のときは自明に  2i <  K でなければならないが、実は  K が奇数のときもそうでなければならない。なぜなら  2i + 1 - K をどこかで踏むことになってしまうからだ。

そうすると、実は

  •  0, 2, ..., 2i, (2i+1, ..., K-1), K+1, K+3, ..., 2i+K-1

という移動 (の途中まで) しかありえないことになる。ここで  2i+1, ..., K-1 の間については自由ではある。そして全体として  n 回以下でなければならない。ここで  L = (K-1) - (2i+1) とおく。

そもそも「1 移動」のジャンプを使わないとき

 K が偶数だったら  \frac{K}{2} 回未満じゃないとダメ。奇数だったら  0, 1, \dots, N 回のいずれでも OK。 i 回ジャンプするそれぞれについて、 C(N, i) 通りになる。

最終到達点が K-1 より手前のとき

 l = m, 1, \dots, L-1 に対して、 m = 0, 1, \dots, n - (i+1) 回のジャンプで移動する場合を合計することになる。各々の場合は  C(m, l-m) 通りある ( m 個の 0 or 1 の合計が  l - m) ので、

 \sum_{m = 0}^{N - (i+1)} (\sum_{l = 0}^{L-m-1} C(m, l)) × C(N, i+1+m)

となる。このうち  \sum_{l = 0}^{L-m-1} C(m, l) のところは累積和などを予め計算しておけば  O(1) でできるので、 O(N) でできる。

最終到達点が K-1 以降のとき

 m = 0, 1, \dots, n-(i+1) について  L の移動をする場合の数は  C(m, L-m) 通りである、それぞれについて残り  min(i, n-(i+1)-m) の弾を撃つかどうかを決めることができるので

 \sum_{m = 0}^{N - (2i+1)} C(m, L-m) × (C(N, i+1+m) + C(N, i+2+m) + \dots + C(N, min(i*2+1+m, n))

となる。この計算は  O(N) でできる。

まとめ

 i に対して  O(N) で計算できるので全体として  O(KN) でできる。

#include <iostream>
#include <vector>
using namespace std;


template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};


const int MOD = 998244353;
const int MAX = 6100;
using mint = Fp<MOD>;


int main() {
    int N, K; cin >> N >> K;

    // 二項係数とその累積和
    vector<vector<mint> > com(MAX, vector<mint>(MAX, 0));
    vector<vector<mint> > scom(MAX, vector<mint>(MAX+1, 0));
    com[0][0] = 1;
    for (int i = 1; i < MAX; ++i) {
        com[i][0] = 1;
        for (int j = 1; j < MAX; ++j) com[i][j] = com[i-1][j-1] + com[i-1][j];
    }
    for (int i = 0; i < MAX; ++i) {
        for (int j = 0; j < MAX; ++j) {
            scom[i][j+1] = scom[i][j] + com[i][j];
        }
    }

    mint res =  0;

    // 2 しかやらない場合
    for (int i = 0; i <= N; ++i) {
        if (K % 2 == 0 && i*2 >= K) continue;
        res += com[N][i];
    }
    for (int i = 0; i*2+1 < K; ++i) {
        int L = (K-1) - (i*2+1);
        mint tmp1 = 0;
        for (int m = 0; m <= min(L, N - (i+1)); ++m) {
            tmp1 += scom[m][L-m] * com[N][m+i+1];
        }
        mint tmp2 = 0;
        for (int m = 0; m <= min(L, N - (i+1)); ++m) {
            tmp2 += com[m][L-m] * (scom[N][min(N, i*2+1+m)+1] - scom[N][i+1+m]);
        }
        res += tmp1 + tmp2;
    }
    cout << res << endl;
}