かなり時間かかった
問題概要
のみからなる長さ の数列であって、どの連続する部分列に対してもその総和が にならないようなものの個数を 998244353 で割ったあまりを求めよ。
制約
考えたこと
最初は包除原理かな...と思ったけど、どうにもうまくできそうにない。そこで のみであることを活かした解法になりそうだという気持ちになる。まず、 は取り除いてしまっても良くて、そうすると
- のみからなる長さ 以下の数列であって、どの連続する部分列に対してもその総和が にならないようなものの個数 (ただし長さ であれば最後に をかける)
を数え上げる問題ということになる。イメージ的には数直線上で、 から出発して、 または ずつ進んで行くが、途中で を踏んだら を踏んではいけないという問題になる。
少し色々手を動かしてみると、結構簡単に -ジャンプは実現してしまいそうな気持ちになる。例えば
- として
- を踏んだら、 は踏めない。
- そうすると 以下から 以上へと行く手段がない。
という感じ。つまり「1 進む」を一回でもやると自由度が一気に減るのだ。ここで「初めて 1 進むのをやる場所」で場合分けしようという気持ちになる。すなわち
という移動をする場合を考えてみる。まず が偶数のときは自明に < でなければならないが、実は が奇数のときもそうでなければならない。なぜなら をどこかで踏むことになってしまうからだ。
そうすると、実は
という移動 (の途中まで) しかありえないことになる。ここで の間については自由ではある。そして全体として 回以下でなければならない。ここで とおく。
そもそも「1 移動」のジャンプを使わないとき
が偶数だったら 回未満じゃないとダメ。奇数だったら 回のいずれでも OK。 回ジャンプするそれぞれについて、 通りになる。
最終到達点が K-1 より手前のとき
に対して、 回のジャンプで移動する場合を合計することになる。各々の場合は 通りある ( 個の 0 or 1 の合計が ) ので、
となる。このうち のところは累積和などを予め計算しておけば でできるので、 でできる。
最終到達点が K-1 以降のとき
について の移動をする場合の数は 通りである、それぞれについて残り の弾を撃つかどうかを決めることができるので
となる。この計算は でできる。
まとめ
各 に対して で計算できるので全体として でできる。
#include <iostream> #include <vector> using namespace std; template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) v += MOD; } constexpr int getmod() { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b; swap(a, b); u -= t * v; swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept { return is >> x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept { if (n == 0) return 1; auto t = modpow(a, n / 2); t = t * t; if (n & 1) t = t * a; return t; } }; const int MOD = 998244353; const int MAX = 6100; using mint = Fp<MOD>; int main() { int N, K; cin >> N >> K; // 二項係数とその累積和 vector<vector<mint> > com(MAX, vector<mint>(MAX, 0)); vector<vector<mint> > scom(MAX, vector<mint>(MAX+1, 0)); com[0][0] = 1; for (int i = 1; i < MAX; ++i) { com[i][0] = 1; for (int j = 1; j < MAX; ++j) com[i][j] = com[i-1][j-1] + com[i-1][j]; } for (int i = 0; i < MAX; ++i) { for (int j = 0; j < MAX; ++j) { scom[i][j+1] = scom[i][j] + com[i][j]; } } mint res = 0; // 2 しかやらない場合 for (int i = 0; i <= N; ++i) { if (K % 2 == 0 && i*2 >= K) continue; res += com[N][i]; } for (int i = 0; i*2+1 < K; ++i) { int L = (K-1) - (i*2+1); mint tmp1 = 0; for (int m = 0; m <= min(L, N - (i+1)); ++m) { tmp1 += scom[m][L-m] * com[N][m+i+1]; } mint tmp2 = 0; for (int m = 0; m <= min(L, N - (i+1)); ++m) { tmp2 += com[m][L-m] * (scom[N][min(N, i*2+1+m)+1] - scom[N][i+1+m]); } res += tmp1 + tmp2; } cout << res << endl; }