けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces Yandex.Algorithm 2011 Round 2 D - Powerful array (R2200)

Mo のアルゴリズムの練習第三弾!!!!!

問題へのリンク

問題概要

 N 要素の整数列  a_1, a_2, \dots, a_N が与えられる。以下の  Q 個のクエリに答えよ:

  • 数列の区間 [  l, r ) 内について、各数値  s について  s が何個出現しているかを  K_s として表して、 K_s^{2} × s を求め、その総和を求めよ

制約

  •  1 \le N, Q \le 2 × 10^{5}
  •  1 \le a_i \le 10^{6}

考えたこと

Mo を知っていれば、やるだけではある。今回は

  • cnt[i] := 区間内に i が何個あるか
  • sum := Ks2 × s の総和

を毎ターン更新していく。sum の更新は、差分更新する。オーバーフローに注意。

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <cmath>
#include <cstring>
using namespace std;

struct Mo {
    vector<int> left, right, index; // the interval's left, right, index
    vector<bool> v;
    int window;
    int nl, nr, ptr;
    
    Mo(int n) : window((int)sqrt(n)), nl(0), nr(0), ptr(0), v(n, false) { }
    
    /* push */
    void push(int l, int r) { left.push_back(l), right.push_back(r); }
    
    /* sort intervals */
    void build() {
        index.resize(left.size());
        iota(index.begin(), index.end(), 0);
        sort(begin(index), end(index), [&](int a, int b)
             {
                 if(left[a] / window != left[b] / window) return left[a] < left[b];
                 return right[a] < right[b];
             });
    }
    
    /* extend-shorten */
    void extend_shorten(int id) {
        v[id].flip();
        if (v[id]) insert(id);
        else erase(id);
    }
    
    /* next id of interval */
    int next() {
        if (ptr == index.size()) return -1;
        int id = index[ptr];
        while (nl > left[id]) extend_shorten(--nl);
        while (nr < right[id]) extend_shorten(nr++);
        while (nl < left[id]) extend_shorten(nl++);
        while (nr > right[id]) extend_shorten(--nr);
        return index[ptr++];
    }
    
    /* insert, erase (to be set appropriately) */
    void insert(int id);
    void erase(int id);
};


int N, Q;
int A[200100];
int cnt[1000100];
long long sum = 0;
long long res[200100];

void Mo::insert(int id) {
    long long val = A[id];
    sum += (long long)(cnt[val] * 2 + 1) * val;
    ++cnt[val];
}

void Mo::erase(int id) {
    long long val = A[id];
    --cnt[val];
    sum -= (long long)(cnt[val] * 2 + 1) * val;
}

int main() {
    scanf("%d %d", &N, &Q);
    memset(cnt, 0, sizeof(cnt));
    for(int i = 0; i < N; i++) scanf("%d", &A[i]);
    Mo mo(N);
    for(int i = 0; i < Q; i++) {
        int x, y;
        scanf("%d %d", &x, &y);
        mo.push(--x, y);
    }
    mo.build();
    for(int i = 0; i < Q; i++) res[mo.next()] = sum;
    for(int i = 0; i < Q; i++) printf("%lld\n", res[i]);
}