けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 101 F - Robots and Exits (赤色, 900 点)

またしても、in-place DP のいい練習になった!!!
最初は絶望感が漂うのだけど、これも結局「必要条件を列挙したら十分条件になっていた」系な気もする。

問題へのリンク

問題概要

 N 個のロボットと、 M 個の穴が一直線上に並んでいる。ロボットは穴に重なると穴に落ちる。今、以下のような操作を好きな順序で好きなだけ行うことができる。

  • すべてのロボットの座標を +1 する
  • すべてのロボットの座標を -1 する

最終的にすべてのロボットがいずれかの穴に落ちるようにする。各ロボットがどの穴から落ちるかの組合せが何通りあるのかを、1000000007 で割ったあまりを求めよ。

制約

  •  1 \le N, M \le 10^{5}
  •  1 \le 各座標値 \le 10^{9}

考えたこと

座標値が 109 とかなので、まともに探索することは厳しい。たとえば

  • 左に  a だけ進んで...
  • 右に  b だけ進んで...

と一往復する方法を調べれば十分なのかな...とか色々考えそうになったけど、どうもそんなことはいえなさそうだと思った。割と振幅を広げながら反復横跳びした方がよい場合がありうる...

こんなときは、こんなことを考えればいいんだと思う。各ロボットが左右のどちらに落ちるのかで  2^{N} 通り考えられるけど、

  • 実は、「ロボット x が左に落ちて、ロボット y が右に落ちることはありえない」というような、局所的なロボットの関係性だけで全体の条件を記述できるのではないか? (できる)
  • 制約がとてもバラバラに感じるけれど、実はロボットをうまいこと順序を決めれば DP できる形になるのではないか?

紐解いてみる。

ロボットの位置を揃える、そして必要条件列挙へ

まず今回の問題は、「各ロボットと左右の穴との相対位置」のみが重要であって、ロボットの位置自体は重要ではない。よって、ロボットの位置を 0 に揃えると

  • 区間  l_{i}, r_{i} (左側は負、右側は正) が  N 個あって、それぞれについて左端と右端のどちらを選ぶか

というフレームワークの問題になる。そして比較的明らかに分かることとして、

  •  l_{i} \le l_{j} かつ  r_{i} \le r_{j} という位置関係にあるような区間  i, j に対して、 l を左側に、 r を右側に落とすことはできない

ということがわかる。他の 3 パターンはできる。あと区間の位置関係はもう 1 つある (一方が他方を包含する場合) が、こっちは 4 パターンすべてできる。なお、区間が完全に一致するやつは 2 パターンしか実現できないが、これは予め除いてしまうことにする。

f:id:drken1215:20191217225812p:plain

そして実は、区間の位置関係が「交差はしているが包含関係にない」ような位置関係についての禁止パターンをすべて守っていたら、その操作が実現可能であることを示すことができる!!!

具体的には、区間の左端が遅い順にソートしておく。そして

  • 残っている区間のうち最初の「左端選択」な区間 p を考えて、それまでの「右端選択」なやつはあらかじめすべて右側に落としてしまう
    • このとき、禁止パターンを守っているから、p のロボットがうっかり右側に落ちてしまうことはない
    • p よりも先のやつが同時に右に落ちる可能性はあるが、それが本当は左端選択であった...なんてこともありえない
  • そして p を左端に落とす
    • このとき、p 以外の区間が落ちてしまうことはない

という感じで、無事にすべて落とすことができる。

DP へ

上記の考察から、とくに区間ソートが有効そうであることもわかる。ここは試行錯誤したのだが、改めて、左端が小さい順にソートし直すことにした。

  • dp[ i ][ j ] := 最初の i 個の区間について、左端を選択した区間の右端の座標の最小値が j (座標圧縮しておく) であるような場合についての、左右パターンの数

とする。このとき、初期条件は、座標圧縮したときの右端座標の個数を s として

  • dp[ 0 ][ s ] = 1

遷移は、

右端を選ぶ場合

  • dp[ i + 1 ][ j ] += dp[ i ][ j ] (任意の j)

左端を選ぶ場合

i 番目の区間の右端の座標を k (座標圧縮しておく) として、

  • dp[ i + 1 ][ k ] += sum_{j = k+1, k+2, ...} dp[ i ][ j ]

という風になる。これは in-place DP が使える形になっているので BIT 上に DP 配列を載せれば OK。

一応の例外処理

  • 区間が完全に一致するやつは除く
  • 区間のうち左右のどちらかが欠けているやつも除く
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;


template <class Abel> struct BIT {
    const Abel UNITY_SUM = 0;                       // to be set
    vector<Abel> dat;
    
    /* [1, n] */
    BIT(int n) : dat(n + 1, UNITY_SUM) { }
    void init(int n) { dat.assign(n + 1, UNITY_SUM); }
    
    /* a is 1-indexed */
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i += i & -i)
            dat[i] = dat[i] + x;
    }
    
    /* [1, a], a is 1-indexed */
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i)
            res = res + dat[i];
        return res;
    }
    
    /* [a, b), a and b are 1-indexed */
    inline Abel sum(int a, int b) {
        return sum(b - 1) - sum(a - 1);
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat.size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};


const long long INF = 1LL<<55;
using pll = pair<long long,long long>;

// 入力
int N, M;
vector<long long> x, y;

mint solve() {
    // 区間に変換
    vector<pll> lrs;
    vector<long long> uright;
    for (int i = 0; i < N; ++i) {
        int it = lower_bound(y.begin(), y.end(), x[i]) - y.begin();
        if (it == 0 || it == M) continue; // 両端は一意なので除く
        lrs.push_back(pll(y[it-1] - x[i], y[it] - x[i]));
        uright.push_back(y[it] - x[i]);
    }

    // 区間ソート
    sort(lrs.begin(), lrs.end());
    lrs.erase(unique(lrs.begin(), lrs.end()), lrs.end()); // 同一の区間は除く
    sort(uright.begin(), uright.end());
    uright.erase(unique(uright.begin(), uright.end()), uright.end());
    int s = uright.size();

    // BIT 上 DP
    BIT<mint> bit(N+10);
    bit.add(s+1, 1);
    for (auto lr : lrs) {
        int k = lower_bound(uright.begin(), uright.end(), lr.second)
            - uright.begin();
        ++k; // 1-indexed に (BIT 対策)
        mint sum = bit.sum(k+1, s+2);
        bit.add(k, sum);
    }
    return bit.sum(1, s+2);
}

int main() {
    cin >> N >> M;
    x.resize(N); y.resize(M);
    for (int i = 0; i < N; ++i) cin >> x[i];
    for (int i = 0; i < M; ++i) cin >> y[i];
    cout << solve() << endl;
}