けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 068 E - Snuke Line (700 点)

こういうのを得意になるぞー!!!
でも AtCoder でこういうデータ構造をしっかり準備する必要がある系は珍しい気がする。

問題へのリンク

問題概要

 M+1 個のマス ( 0, 1, \dots, M) があって、マス上に  N 個の区間がある。

 d = 1, 2, \dots, M に対して、 0, d, 2d, \dots と移動したときに何個の区間を踏むかを答えよ (二度以上同じ区間を踏む場合は 1 個とカウントする)

制約

  •  1 \le N \le 3 × 10^{5}
  •  1 \le M \le 10^{5}

考えたこと

ここまでハッキリと「調和級数の和の計算量の構造になってるよ」と問題文に明記されているのは珍しい気がする。

 d 個おきにぴょんぴょんするときに、同じ区間を二度踏んでしまう場合を扱うのが大変そう。でも少し考えると、

  •  d 個おきに飛ぶとき、長さ  d 以上の区間は必ず踏む

ということがわかる。よって  d を小さい順に見ていくとき、区間を長さが短い順にソートしておくと、区間が左から順番に「絶対に踏むとは限らない」という側へと脱落して行くイメージ!!!!!

なので「脱落区間を重ね合わせたもの」というのを用意しておいて、毎回の  d ごとに、 d-1 から  d になった瞬間に長さ  d-1区間があったならそれが脱落するので、それを「脱落区間を重ね合わせたもの」に足しあわせて行く。

そして「脱落区間」は個別にぴょんぴょんして数えてあげて、それに「脱落していない区間の個数」を足したものが答えになる。

残る問題は「脱落区間を重ね合わせたもの」に「脱落した区間」を足し合わせる部分であるが、これを  O(\log{M}) でできるデータ構造がある。これを用いると一点の値の取得にも  O(\log{M}) かかってしまうので、全体として

  • 区間を長さが短い順にソート:  O(N\log{N})
  • 加算クエリ:  O(N\log{M})
  • ぴょんぴょんクエリ:  O(M(\log{M})^{2})

の計算時間となる。

github.com

#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
using namespace std;


template <class Abel> struct BIT {
    vector<Abel> dat[2];
    Abel UNITY_SUM = 0;                     // to be set
    
    /* [1, n] */
    BIT(int n) { init(n); }
    void init(int n) { for (int iter = 0; iter < 2; ++iter) dat[iter].assign(n + 1, UNITY_SUM); }
    
    /* a, b are 1-indexed, [a, b) */
    inline void sub_add(int p, int a, Abel x) {
        for (int i = a; i < (int)dat[p].size(); i += i & -i)
            dat[p][i] = dat[p][i] + x;
    }
    inline void add(int a, int b, Abel x) {
        sub_add(0, a, x * -(a - 1)); sub_add(1, a, x); sub_add(0, b, x * (b - 1)); sub_add(1, b, x * (-1));
    }
    
    /* a is 1-indexed, [a, b) */
    inline Abel sub_sum(int p, int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i) res = res + dat[p][i];
        return res;
    }
    inline Abel sum(int a, int b) {
        return sub_sum(0, b - 1) + sub_sum(1, b - 1) * (b - 1) - sub_sum(0, a - 1) - sub_sum(1, a - 1) * (a - 1);
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat[0].size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};


using pint = pair<int,int>;
int N, M;
vector<pint> inter;

int main() {
    cin >> N >> M;
    inter.resize(N);
    for (int i = 0; i < N; ++i)
        cin >> inter[i].first >> inter[i].second, ++inter[i].second;
    sort(inter.begin(), inter.end(), [&](pint i, pint j) {
            return i.second - i.first < j.second - j.first;});

    BIT<int> bit(M + 10);
    int pos = 0;
    for (int d = 1; d <= M; ++d) {
        while (pos < N && inter[pos].second - inter[pos].first <= d-1) {
            bit.add(inter[pos].first, inter[pos].second, 1);
            ++pos;
        }
        int res = 0;
        for (int i = d; i <= M; i += d) res += bit.sum(i, i+1);
        res += (N - pos);
        cout << res << endl;
    }
}