DP 高速化系問題。こういうのが緑 diff になるようになったんかーー (水色 diff にアップグレードした!)
問題概要
マスからなるマス目が与えられる。また、 個の互いに disjoint な区間 が与えられる。この区間に属する整数からなる集合を とする。
マス 1 からマス まで、次の操作を繰り返して到達したい。そのような操作列の個数を 998244353 で割ったあまりを求めよ。
- に含まれる整数 を 1 個選んで、マスを 個分進める
制約
解法 (1):DP 高速化
0-indexed で考える。この手の DP はとてもよく見る。こんな感じ
- := マス に到達するまでの操作列の個数
そうすると、こんな感じで更新できる。
しかしこのままだと、DP の状態量が だけあって、それぞれについて 通りの遷移を考えることになるので、全体としては の計算量となってしまう。高速化が必要となる。
累積和を用いて DP 高速化
こんなときは累積和を用いて DP 高速化するのは定番ではある。まず、集合 が 個の区間からなることに着目すると、 の部分は、 個の区間についての「区間内の総和」を足し上げたものになることがわかる。
ここで一般に、配列 a の累積和を s としたとき、配列 a の区間 [l, r) の総和は s[r] - s[l] で表せることを思い出そう。よって配列 dp の累積和を sdp とすると、更新式は次のように変形できる。
よって、次のようにすれば OK。計算量は となる。
各 i に対して
- dp[ i ] の値がもとまったら
- sdp[ i + 1 ] = sdp[ i ] + dp[ i ] によって累積和も同時に 1 マス更新する
#include <bits/stdc++.h> using namespace std; // modint template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; } constexpr int getmod() const { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept { is >> x.val; x.val %= MOD; if (x.val < 0) x.val += MOD; return is; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept { if (n == 0) return 1; auto t = modpow(a, n / 2); t = t * t; if (n & 1) t = t * a; return t; } }; const int MOD = 998244353; using mint = Fp<MOD>; using pint = pair<int,int>; int main() { int N, K; cin >> N >> K; vector<pint> v(K); for (int i = 0; i < K; ++i) cin >> v[i].first >> v[i].second; vector<mint> dp(N, 0), sdp(N + 1, 0); dp[0] = 1, sdp[1] = 1; for (int n = 1; n < N; ++n) { for (auto p : v) { int left = max(0, n - p.second); int right = max(0, n - p.first + 1); dp[n] += sdp[right] - sdp[left]; } sdp[n+1] = sdp[n] + dp[n]; } cout << dp[N - 1] << endl; }
解法 (2):形式的冪級数
今回は区間が 個という特殊制約を用いることで計算量を削減できるパターンだったが、集合 が一般の場合であっても で解くことができる。一般に、部分和問題を扱うような DP は FFT 系統の解法でいい感じに扱えることがよくある気がする。
- の元 に対して、 の係数が 1 であるような多項式
を考えたとき、次のようになる。
- の の係数 = 2 回の操作で ステップ進む場合の数
- の の係数 = 3 回の操作で ステップ進む場合の数
- ...
これらをすべて足したいので、結局
- の の係数
を求めればよいことになる。
#include <bits/stdc++.h> using namespace std; // modint template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; } constexpr int getmod() const { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept { is >> x.val; x.val %= MOD; if (x.val < 0) x.val += MOD; return is; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept { if (n == 0) return 1; auto t = modpow(a, n / 2); t = t * t; if (n & 1) t = t * a; return t; } }; namespace NTT { long long modpow(long long a, long long n, int mod) { long long res = 1; while (n > 0) { if (n & 1) res = res * a % mod; a = a * a % mod; n >>= 1; } return res; } long long modinv(long long a, int mod) { long long b = mod, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } u %= mod; if (u < 0) u += mod; return u; } int calc_primitive_root(int mod) { if (mod == 2) return 1; if (mod == 167772161) return 3; if (mod == 469762049) return 3; if (mod == 754974721) return 11; if (mod == 998244353) return 3; int divs[20] = {}; divs[0] = 2; int cnt = 1; long long x = (mod - 1) / 2; while (x % 2 == 0) x /= 2; for (long long i = 3; i * i <= x; i += 2) { if (x % i == 0) { divs[cnt++] = i; while (x % i == 0) x /= i; } } if (x > 1) divs[cnt++] = x; for (int g = 2;; g++) { bool ok = true; for (int i = 0; i < cnt; i++) { if (modpow(g, (mod - 1) / divs[i], mod) == 1) { ok = false; break; } } if (ok) return g; } } int get_fft_size(int N, int M) { int size_a = 1, size_b = 1; while (size_a < N) size_a <<= 1; while (size_b < M) size_b <<= 1; return max(size_a, size_b) << 1; } // number-theoretic transform template<class mint> void trans(vector<mint> &v, bool inv = false) { if (v.empty()) return; int N = (int)v.size(); int MOD = v[0].getmod(); int PR = calc_primitive_root(MOD); static bool first = true; static vector<long long> vbw(30), vibw(30); if (first) { first = false; for (int k = 0; k < 30; ++k) { vbw[k] = modpow(PR, (MOD - 1) >> (k + 1), MOD); vibw[k] = modinv(vbw[k], MOD); } } for (int i = 0, j = 1; j < N - 1; j++) { for (int k = N >> 1; k > (i ^= k); k >>= 1); if (i > j) swap(v[i], v[j]); } for (int k = 0, t = 2; t <= N; ++k, t <<= 1) { long long bw = vbw[k]; if (inv) bw = vibw[k]; for (int i = 0; i < N; i += t) { mint w = 1; for (int j = 0; j < t/2; ++j) { int j1 = i + j, j2 = i + j + t/2; mint c1 = v[j1], c2 = v[j2] * w; v[j1] = c1 + c2; v[j2] = c1 - c2; w *= bw; } } } if (inv) { long long invN = modinv(N, MOD); for (int i = 0; i < N; ++i) v[i] = v[i] * invN; } } // for garner static constexpr int MOD0 = 754974721; static constexpr int MOD1 = 167772161; static constexpr int MOD2 = 469762049; using mint0 = Fp<MOD0>; using mint1 = Fp<MOD1>; using mint2 = Fp<MOD2>; static const mint1 imod0 = 95869806; // modinv(MOD0, MOD1); static const mint2 imod1 = 104391568; // modinv(MOD1, MOD2); static const mint2 imod01 = 187290749; // imod1 / MOD0; // small case (T = mint, long long) template<class T> vector<T> naive_mul (const vector<T> &A, const vector<T> &B) { if (A.empty() || B.empty()) return {}; int N = (int)A.size(), M = (int)B.size(); vector<T> res(N + M - 1); for (int i = 0; i < N; ++i) for (int j = 0; j < M; ++j) res[i + j] += A[i] * B[j]; return res; } // mint template<class mint> vector<mint> mul (const vector<mint> &A, const vector<mint> &B) { if (A.empty() || B.empty()) return {}; int N = (int)A.size(), M = (int)B.size(); if (min(N, M) < 30) return naive_mul(A, B); int MOD = A[0].getmod(); int size_fft = get_fft_size(N, M); if (MOD == 998244353) { vector<mint> a(size_fft), b(size_fft), c(size_fft); for (int i = 0; i < N; ++i) a[i] = A[i]; for (int i = 0; i < M; ++i) b[i] = B[i]; trans(a), trans(b); vector<mint> res(size_fft); for (int i = 0; i < size_fft; ++i) res[i] = a[i] * b[i]; trans(res, true); res.resize(N + M - 1); return res; } vector<mint0> a0(size_fft, 0), b0(size_fft, 0), c0(size_fft, 0); vector<mint1> a1(size_fft, 0), b1(size_fft, 0), c1(size_fft, 0); vector<mint2> a2(size_fft, 0), b2(size_fft, 0), c2(size_fft, 0); for (int i = 0; i < N; ++i) a0[i] = A[i].val, a1[i] = A[i].val, a2[i] = A[i].val; for (int i = 0; i < M; ++i) b0[i] = B[i].val, b1[i] = B[i].val, b2[i] = B[i].val; trans(a0), trans(a1), trans(a2), trans(b0), trans(b1), trans(b2); for (int i = 0; i < size_fft; ++i) { c0[i] = a0[i] * b0[i]; c1[i] = a1[i] * b1[i]; c2[i] = a2[i] * b2[i]; } trans(c0, true), trans(c1, true), trans(c2, true); static const mint mod0 = MOD0, mod01 = mod0 * MOD1; vector<mint> res(N + M - 1); for (int i = 0; i < N + M - 1; ++i) { int y0 = c0[i].val; int y1 = (imod0 * (c1[i] - y0)).val; int y2 = (imod01 * (c2[i] - y0) - imod1 * y1).val; res[i] = mod01 * y2 + mod0 * y1 + y0; } return res; } // long long vector<long long> mul_ll (const vector<long long> &A, const vector<long long> &B) { if (A.empty() || B.empty()) return {}; int N = (int)A.size(), M = (int)B.size(); if (min(N, M) < 30) return naive_mul(A, B); int size_fft = get_fft_size(N, M); vector<mint0> a0(size_fft, 0), b0(size_fft, 0), c0(size_fft, 0); vector<mint1> a1(size_fft, 0), b1(size_fft, 0), c1(size_fft, 0); vector<mint2> a2(size_fft, 0), b2(size_fft, 0), c2(size_fft, 0); for (int i = 0; i < N; ++i) a0[i] = A[i], a1[i] = A[i], a2[i] = A[i]; for (int i = 0; i < M; ++i) b0[i] = B[i], b1[i] = B[i], b2[i] = B[i]; trans(a0), trans(a1), trans(a2), trans(b0), trans(b1), trans(b2); for (int i = 0; i < size_fft; ++i) { c0[i] = a0[i] * b0[i]; c1[i] = a1[i] * b1[i]; c2[i] = a2[i] * b2[i]; } trans(c0, true), trans(c1, true), trans(c2, true); static const long long mod0 = MOD0, mod01 = mod0 * MOD1; vector<long long> res(N + M - 1); for (int i = 0; i < N + M - 1; ++i) { int y0 = c0[i].val; int y1 = (imod0 * (c1[i] - y0)).val; int y2 = (imod01 * (c2[i] - y0) - imod1 * y1).val; res[i] = mod01 * y2 + mod0 * y1 + y0; } return res; } }; // Binomial coefficient template<class T> struct BiCoef { vector<T> fact_, inv_, finv_; constexpr BiCoef() {} constexpr BiCoef(int n) noexcept : fact_(n, 1), inv_(n, 1), finv_(n, 1) { init(n); } constexpr void init(int n) noexcept { fact_.assign(n, 1), inv_.assign(n, 1), finv_.assign(n, 1); int MOD = fact_[0].getmod(); for(int i = 2; i < n; i++){ fact_[i] = fact_[i-1] * i; inv_[i] = -inv_[MOD%i] * (MOD/i); finv_[i] = finv_[i-1] * inv_[i]; } } constexpr T com(int n, int k) const noexcept { if (n < k || n < 0 || k < 0) return 0; return fact_[n] * finv_[k] * finv_[n-k]; } constexpr T fact(int n) const noexcept { if (n < 0) return 0; return fact_[n]; } constexpr T inv(int n) const noexcept { if (n < 0) return 0; return inv_[n]; } constexpr T finv(int n) const noexcept { if (n < 0) return 0; return finv_[n]; } }; // Formal Power Series template <typename mint> struct FPS : vector<mint> { using vector<mint>::vector; // constructor FPS(const vector<mint> &r) : vector<mint>(r) {} // core operator FPS pre(int siz) const noexcept { return FPS(begin(*this), begin(*this) + min((int)this->size(), siz)); } FPS rev() const noexcept { reverse(begin(*this), end(*this)); return *this; } FPS& normalize() noexcept { while (!this->empty() && this->back() == 0) this->pop_back(); return *this; } FPS inv(int deg) const noexcept { assert((*this)[0] != 0); if (deg < 0) deg = (int)this->size(); FPS res({mint(1) / (*this)[0]}); for (int i = 1; i < deg; i <<= 1) { res = (res + res - res * res * this->pre(i << 1)).pre(i << 1); } return res.pre(deg); } // each operator FPS operator - () const noexcept { FPS res = (*this); for (int i = 0; i < (int)res.size(); ++i) res[i] = -res[i]; return res; } FPS operator + (const mint &v) const noexcept { return FPS(*this) += v; } FPS operator + (const FPS& r) const noexcept { return FPS(*this) += r; } FPS operator - (const mint &v) const noexcept { return FPS(*this) -= v; } FPS operator - (const FPS& r) const noexcept { return FPS(*this) -= r; } FPS operator * (const mint &v) const noexcept { return FPS(*this) *= v; } FPS operator * (const FPS& r) const noexcept { return FPS(*this) *= r; } FPS operator / (const mint &v) const noexcept { return FPS(*this) /= v; } FPS operator / (const FPS& r) const noexcept { return FPS(*this) /= r; } FPS operator % (const mint &v) const noexcept { return FPS(*this) %= v; } FPS operator % (const FPS& r) const noexcept { return FPS(*this) %= r; } FPS& operator += (const mint &v) { if (this->empty()) this->resize(1); (*this)[0] += v; return *this; } FPS& operator += (const FPS &r) { if (r.size() > this->size()) this->resize(r.size()); for (int i = 0; i < (int)r.size(); ++i) (*this)[i] += r[i]; return this->normalize(); } FPS& operator -= (const mint &v) { if (this->empty()) this->resize(1); (*this)[0] -= v; return *this; } FPS& operator -= (const FPS &r) { if (r.size() > this->size()) this->resize(r.size()); for (int i = 0; i < (int)r.size(); ++i) (*this)[i] -= r[i]; return this->normalize(); } FPS& operator *= (const mint &v) { for (int i = 0; i < (int)this->size(); ++i) (*this)[i] *= v; return *this; } FPS& operator *= (const FPS &r) { return *this = NTT::mul((*this), r); } FPS& operator /= (const mint &v) { assert(v != 0); mint iv = v.inv(); for (int i = 0; i < (int)this->size(); ++i) (*this)[i] *= iv; return *this; } FPS& operator /= (const FPS &r) { if (this->size() < r.size()) { this->clear(); return *this; } int need = (int)this->size() - (int)r.size() + 1; *this = ((*this).rev().pre(need) * r.rev().inv(need)).pre(need).rev(); return *this; } }; const int MOD = 998244353; using mint = Fp<MOD>; int main() { int N, K; cin >> N >> K; FPS<mint> f(N, 0); for (int i = 0; i < K; ++i) { int l, r; cin >> l >> r; for (int j = l; j <= r; ++j) if (j < N) f[j] = 1; } auto g = -f + 1; auto res = g.inv(N + 10); if (res.size() >= N) cout << res[N-1] << endl; else cout << 0 << endl; }