けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

SPOJ DQUERY - D-query

種類数クエリをマスターするぞ!

問題へのリンク

問題概要

 N 要素の整数列  a_1, a_2, \dots, a_N が与えられる。以下の  Q 個のクエリに答えよ:

  • 数列の区間 [  l, r ) 内に何種類の整数があるかを答えよ

制約

  •  1 \le N \le 30000
  •  1 \le Q \le 2000000
  •  1 \le a_i \le 10^{6}

解法 1: BIT (区間加算、1 点取得)

まずは BIT を用いる方法から。
考えやすいように、クエリを「区間の右端」でソートしておく。そうすると、以下のことが成立する。


  • クエリ  l, r に対して、区間 [0, r) の範囲のみについて、数列中の同じ値同士を線で結ぶ (左端に任意の数値があるものとする)

  • このとき、l のところを横切る線の本数が求める「種類数」である


f:id:drken1215:20200116165836p:plain

実装上は、双対 BIT (区間加算、1 点取得) を用いて

  • マス l とマス r が等しいことを検知したら区間 [l + 1, r) に 1 を加算
    • ここで、左端に置いている「任意の値」を表す番兵マスの index は l = -1 とする
  • クエリ [l, r) に対しては、[l, l+1) の値を取得する

という風にすれば OK。

f:id:drken1215:20200116170716p:plain

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

template <class Abel> struct BIT {
    vector<Abel> dat[2];
    Abel UNITY_SUM = 0;                     // to be set
    
    /* [1, n] */
    BIT(int n) { init(n); }
    void init(int n) { for (int iter = 0; iter < 2; ++iter) dat[iter].assign(n + 1, UNITY_SUM); }
    
    /* a, b are 1-indexed, [a, b) */
    inline void sub_add(int p, int a, Abel x) {
        for (int i = a; i < (int)dat[p].size(); i += i & -i)
            dat[p][i] = dat[p][i] + x;
    }
    inline void add(int a, int b, Abel x) {
        sub_add(0, a, x * -(a - 1)); sub_add(1, a, x); sub_add(0, b, x * (b - 1)); sub_add(1, b, x * (-1));
    }
    
    /* a is 1-indexed, [a, b) */
    inline Abel sub_sum(int p, int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i) res = res + dat[p][i];
        return res;
    }
    inline Abel sum(int a, int b) {
        return sub_sum(0, b - 1) + sub_sum(1, b - 1) * (b - 1) - sub_sum(0, a - 1) - sub_sum(1, a - 1) * (a - 1);
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat[0].size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};

int main() {
    int N; cin >> N;
    vector<int> a(N);
    for (int i = 0; i < N; ++i) cin >> a[i];

    int Q; cin >> Q;
    vector<int> lefts(Q), rights(Q), ids(Q);
    for (int i = 0; i < Q; ++i) {
        cin >> lefts[i] >> rights[i];
        --lefts[i];
        ids[i] = i;
    }
    sort(ids.begin(), ids.end(), [&](int i, int j) {
            return rights[i] < rights[j];});

    BIT<int> bit(N+5);
    vector<int> prev(1100000, -1);
    vector<int> res(Q, 0);
    int r = 0;
    for (auto i : ids) {
        for (; r < rights[i]; ++r) {
            bit.add(prev[a[r]]+2, r+2, 1);
            prev[a[r]] = r;
        }
        int tmp = bit.sum(lefts[i]+1, lefts[i]+2);
        res[i] = max(res[i], tmp);
    }
    for (int i = 0; i < Q; ++i) cout << res[i] << endl;
}

解法 2: Mo 法

うしさんの記事の写経。Mo 法を手に入れた!!!

ライブラリの使い方としては、

int A[300001];
int res[200001];
int cnt[1000001];
int num_kind = 0;

void Mo::insert(int id) {
    int val = A[id];
    if (cnt[val] == 0) ++num_kind;
    ++cnt[val];
}

void Mo::erase(int id) {
    int val = A[id];
    --cnt[val];
    if (cnt[val] == 0) --num_kind;
}

の部分を毎回個別に書く。区間に対する情報を、

  • insert(id): 数 A[id] を区間に新たに加えるとどうなるかを更新
  • erase(id): 数 A[id] を区間から除くと加えるとどうなるかを更新
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <cmath>
using namespace std;

struct Mo {
    vector<int> left, right, index; // the interval's left, right, index
    vector<bool> v;
    int window;
    int nl, nr, ptr;
    
    Mo(int n) : window((int)sqrt(n)), nl(0), nr(0), ptr(0), v(n, false) { }
    
    /* push */
    void push(int l, int r) { left.push_back(l), right.push_back(r); }
    
    /* sort intervals */
    void build() {
        index.resize(left.size());
        iota(index.begin(), index.end(), 0);
        sort(begin(index), end(index), [&](int a, int b)
        {
            if (left[a] / window != left[b] / window) return left[a] < left[b];
            return bool((right[a] < right[b]) ^ (left[a] / window % 2));
        });
        
        //        sort(index.begin(), index.end(), [&](int a, int b) {
        //            if (left[a] / window != right[b] / window) return left[a] < left[b];
        //            else return right[a] < right[b];
        //        });
    }
    
    /* extend-shorten */
    void extend_shorten(int id) {
        v[id].flip();
        if (v[id]) insert(id);
        else erase(id);
    }
    
    /* next id of interval */
    int next() {
        if (ptr == index.size()) return -1;
        int id = index[ptr];
        while (nl > left[id]) extend_shorten(--nl);
        while (nr < right[id]) extend_shorten(nr++);
        while (nl < left[id]) extend_shorten(nl++);
        while (nr > right[id]) extend_shorten(--nr);
        return index[ptr++];
    }
    
    /* insert, erase (to be set appropriately) */
    void insert(int id);
    void erase(int id);
};

int N, Q;
int A[300001];
int res[200001];
int cnt[1000001];
int num_kind = 0;

void Mo::insert(int id) {
    int val = A[id];
    if (cnt[val] == 0) ++num_kind;
    ++cnt[val];
}

void Mo::erase(int id) {
    int val = A[id];
    --cnt[val];
    if (cnt[val] == 0) --num_kind;
}

int main() {
    scanf("%d", &N);
    for(int i = 0; i < N; i++) scanf("%d", &A[i]);
    scanf("%d", &Q);
    Mo mo(N);
    for(int i = 0; i < Q; i++) {
        int l, r; scanf("%d %d", &l, &r);
        mo.push(--l, r);
    }
    mo.build();
    for (int i = 0; i < Q; i++){
        int idx = mo.next();
        res[idx] = num_kind;
    }
    for(int i = 0; i < Q; i++) printf("%d\n", res[i]);
}