実家 DP 苦手すぎる。今回は解法を簡単なものにするにあたって、「区間の左端も右端も単調増加と思って良い」というのが、割と効いてる気がする。
問題概要
長さ の数列であって、各要素の値が 以上 未満であるもののうち、以下の 個の条件を満たすものの個数を 998244353 で割ったあまりを求めよ。
- 各条件は区間と値が指定される
- 数列の区間の AND をとった値が、指定された値に一致する
制約
考えたこと
すっごく典型色強めで面白そう!!!
こういうのはとりあえず区間が交差しているところについて考えないといけなくて、ひとまず
- 区間が交差しているところの条件の値が異なっているものについては、結局どちらかだけを考えればよいか、両方とも無視してよくなる
- 条件の値が等しいとこだけ慎重に扱えば良い
といったあたりのことが見えてきた。のだが、この時点でも複雑だな...と思っていた。もっと簡単に
- 各桁ごとに完全に独立に考えて良い
ということを見落としていた。
各桁ごとの問題
各桁ごとに長さ の 0, 1 列を作る方法を考えて、それを掛け算すればよい。具体的には各区間について
- 区間のすべての値が 1
- 区間内に 0 が含まれる
という条件で表されることになる。前者については何も考えなくて OK。いもす法とかによって、あらかじめ 1 にしかならない場所を列挙しておくことはできる。
後者の条件は、包除原理したくなるけど、あんまり上手くいかない。でもこういうのはインライン DP でできるイメージはある。ひとまず、区間を右端順にソートしておく。
インライン DP
この手の区間をどうのこうのする DP、「ある状態がどこまで続くのかを添字に持つ」という風にすればよいイメージがある。こうしてみる
- dp[ i ][ j ] := 数列の [0, i) について 0 か 1 を決める方法であって、区間の右端が [0, i) の内部に収まるようなものについての条件はすべて満たしていて、区間 [j, i) の値はすべて 1 で、a[ j ] の値は 0 であるようなものの個数
としてみる。DP の更新をするとき、右端を i にするような区間が複数個あるとき一瞬迷うのだけど、今回については、そのうちの左端が最も右にあるものだけを考えれば OK。その値を L[ i ] とする。なお、右端を i にするような区間が存在しないときは、便宜的に L[ i ] = -1 と考えて良さそう。このとき、
a[ i ] が 1 でなくてもよいときのみ
- dp[ i + 1 ][ i + 1 ] += dp[ i ][ j ] ・・・ j in [0, i+1)
一般
- dp[ i + 1 ][ j ] += dp[ i ][ j ] ・・・ j in [L[ i + 1 ] + 1, i + 1)
- dp[ i + 1 ][ j ] = 0 ・・・ j in [0, L[ i + 1 ] + 1)
という風に更新できる。2 番目の式については in-place 化できて、3 番目の式も上手に遅延評価セグ木を使えば...という感じ。でも面倒。そこでさらに工夫する。
区間の包含関係を除去
実は今回は、3 番目の「dp の値を 0 にする」というのは不要になる。なぜなら、今回は
- 区間 A が区間 B に含まれているような状態のとき、区間 B は無視してよい
- したがって、区間の右端を単調増加になるように並べたとき、区間の左端も単調増加であると考えて良い
という状態になっているからだ。これを考慮すると、以下の感じで OK。計算量は 。
a[ i ] が 1 でなくてもよいときのみ
- dp[ i + 1 ][ i + 1 ] += dp[ i ][ j ] ・・・ j in [L[ i ] + 1, i+1)
一般 (in-place 化で消える)
- dp[ i + 1 ][ j ] += dp[ i ][ j ] ・・・ j in [0, i + 1)
#Include <bits/stdc++.h> using namespace std; template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; } template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; } template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; } constexpr int getmod() { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b; swap(a, b); u -= t * v; swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept { if (n == 0) return 1; auto t = modpow(a, n / 2); t = t * t; if (n & 1) t = t * a; return t; } }; const int MOD = 998244353; using mint = Fp<MOD>; struct Section { int x, y, val; Section(int x = 0, int y = 0, int val = 0) : x(x), y(y), val(val) {} }; ostream& operator << (ostream &ss, Section s) { return ss << "(" << s.x << "," << s.y << "," << s.val << ")" << endl; } int N, M, K; vector<Section> secs; mint solve() { sort(secs.begin(), secs.end(), [&](Section a, Section b) { return a.y < b.y; }); mint res = 1; for (int d = 0; d < K; ++d) { vector<int> one(N+1, 0), left(N+1, -1); for (auto s : secs) { if (s.val & (1<<d)) one[s.x]++, one[s.y]--; else chmax(left[s.y], s.x); } for (int i = 0; i < N; ++i) { one[i+1] += one[i]; chmax(left[i+1], left[i]); } vector<mint> dp(N+1, 0), sdp(N+2, 0); dp[0] = 1; sdp[1] = 1; for (int i = 0; i < N; ++i) { if (one[i] == 0) dp[i+1] = sdp[i+1] - sdp[left[i]+1]; sdp[i+2] = sdp[i+1] + dp[i+1]; } mint tmp = sdp[N+1] - sdp[left[N]+1]; res *= tmp; } return res; } int main() { while (scanf("%d %d %d", &N, &K, &M) != EOF) { secs.resize(M); for (int i = 0; i < M; ++i) { scanf("%d %d %d", &secs[i].x, &secs[i].y, &secs[i].val); --secs[i].x; } printf("%lld\n", solve().val); } }