けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 3181 Proper Instructions (HUPC 2020 day3-J)

比較的素直な考察で解ける問題

問題へのリンク

問題概要

umgくんは 1 次元上の座標 0 にいます。今は時刻 0 です。時刻が 1 進むごとに、今いる座標より 1 大きい座標に移動するか、 1 小さい座標に移動するか、その座標にとどまるかという行動ができます。

 N 個の指示が与えられます。  i 個目の指示は、「時刻  T_i には  L_{i} \le x \le R_{i} を満たす座標  x にいなければならない」という指示です。

 N 個の指示の空でない部分集合  S が「適切」であるとは、umgくんが上手く動くことで、 S に含まれるすべての指示に従うことができることをいいます。適切である指示の集合としてあり得るものの個数を 998244353 で割った余りを求めてください。

制約

  •  1 \le N \le 300
  •  1 \le T_{1} \lt \dots \lt T_{N} \le 10^{9}
  •  -10^{9} \le L_{i} \le R_{i} \le 10^{9}

考えたこと

数え上げ問題では、まず判定問題を解くのが定石ではある!

時刻順にソートして、存在可能区間を遷移していけば OK。i 番目の区間で [l, r) の範囲にいることが可能であった場合、j 番目の区間では dt = tj - ti として、[max(l - dt, Lj), min(r + dt, Rj)] の範囲内にいることができる (これが空だったら両立不可)。

DP へ

以上の考察から、とりあえずこんな DP が考えられる

  • dp[ i ][ l ][ r ] := i 番目の区間を考えている段階で移動可能区間が [l, r] であるような状態にするための、区間の選び方の場合の数

このままでは状態量が多すぎる。しかし、この値が正になるような (l, r) の組合せはとても少ない。具体的には、次のようになる。

  • l としてとりうる値は、右へと進みながら時刻を遡ったときに、ある区間の左端になるような点 ( O(N) 個しかない)
  • r としてとりうる値は、左へと進みながら時刻を遡ったときに、ある区間の右端になるような点 ( O(N) 個しかない)

よって状態を圧縮して  O(N^{3}) で解ける。

コード

あんまりきれいじゃないかも...

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
using mint = Fp<998244353>;

int main() {
    int N;
    cin >> N;
    vector<long long> T(N), L(N), R(N);
    for (int i = 0; i < N; ++i) cin >> T[i] >> L[i] >> R[i];

    int M = N*2 + 10;
    vector<vector<mint>> dp(M, vector<mint>(M, 0));
    auto ndp = dp;
    dp[0][0] = 1;
    vector<long long> left({0}), right({0}), nleft, nright;
    long long curT = 0;
    for (int i = 0; i < N; ++i) {
        ndp.assign(M, vector<mint>(M, 0));
        nleft = vector<long long>({-T[i], L[i]});
        nright = vector<long long>({T[i], R[i]});
        for (int j = 0; j < i; ++j) {
            nleft.push_back(L[j] - T[i] + T[j]);
            nright.push_back(R[j] + T[i] - T[j]);
        }
        sort(nleft.begin(), nleft.end());
        sort(nright.begin(), nright.end());
        nleft.erase(unique(nleft.begin(), nleft.end()), nleft.end());
        nright.erase(unique(nright.begin(), nright.end()), nright.end());
        for (int l = 0; l < left.size(); ++l) {
            for (int r = 0; r < right.size(); ++r) {
                if (dp[l][r] == 0) continue;

                // 区間 i を選ばない場合
                int nl = lower_bound(nleft.begin(), nleft.end(), left[l] - T[i] + curT) - nleft.begin();
                int nr = lower_bound(nright.begin(), nright.end(), right[r] + T[i] - curT) - nright.begin();
                ndp[nl][nr] += dp[l][r];

                // 区間 i を選ぶ場合
                long long nL = max(L[i], left[l] - T[i] + curT);
                long long nR = min(R[i], right[r] + T[i] - curT);
                if (nL <= nR) {
                    nl = lower_bound(nleft.begin(), nleft.end(), nL) - nleft.begin();
                    nr = lower_bound(nright.begin(), nright.end(), nR) - nright.begin();
                    ndp[nl][nr] += dp[l][r];
                }
            }
        }
        swap(ndp, dp);
        swap(nleft, left);
        swap(nright, right);
        curT = T[i];
    }
    mint res = 0;
    for (int l = 0; l < left.size(); ++l) {
        for (int r = 0; r < right.size(); ++r) {
            res += dp[l][r];
        }
    }
    cout << res - 1 << endl;
}