けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces Round #488 (Div. 1) E. Nikita and Order Statistics (R2300)

FFT 勉強シリーズその 2。 うーん、、、これ思いつけるもんなんかいな......

Codeforces 488 DIV1 E - Nikita and Order Statistics

問題概要

 N 要素の整数数列  a_{0}, a_{1}, \dots, a_{N-1} と整数  x が与えられる。数列  a の連続する部分列のうち、「 x 未満の値となっているものが  k 個ある」という条件を満たすものの個数を各  k = 0, 1, 2, \dots, N に対して求めよ。

制約

  •  1 \le N \le 2×10^{5}

解法

 O(N^{2}) で間に合わない問題ではあるのだが、まずは  O(N^{2}) な解法を考えてみる。 素直に全区間について、 x 未満が何個あるかを数えて集計すればよい。 x 未満が何個あるのかは累積和をはじめにとっておくと簡単に計算できる。

さて、例えば

  • A = (3, 7, 5, 1, 4, 6)
  • x = 5

で考えてみる。A の各要素は x より小さいかどうかだけが大事で

A = (1, 0, 0, 1, 1, 0)

みたいな感じになる。この累積和をとると

S = (0, 1, 1, 1, 2, 3, 3)

となる。S から二要素を選んでその差ごとに集計する問題であると言える。さっきの問題と同じように S を個数集計ベクトルに変換してみる:

(1, 3, 1, 2, 0, 0, 0) (それぞれ 0, 1, 2, 3, 4, 5, 6 が何個あるか)

さて、S の二要素 S[i], S[j] との差が 2 であるとはつまり、個数集計ベクトルにおいて差が 2 となっている index 同士の積を足したものになっている!実際

  • S が 0 の要素は 1 個、2 のやつは 1 個で掛けて 1 個
  • S が 1 の要素は 3 個、3 の要素は 2 個で掛けて 6 個

となってこれらを合わせて 7 個となるのだが、これは個数集計ベクトルで見ると明快である。注意点として「差が 0」のところだけややこしいのでそれは無視して考えることにする (差が 1 以上のところだけ求めて後で全体から引けば OK)。

こうして、個数集計ベクトルにおいて「index の差が k 同士の要素の積の総和」を計算する問題になったのだが、まさに FFT がハマる設定である。さっきの問題と同じように、個数集計ベクトルを片方ひっくり返してあげればよい。

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

struct ComplexNumber {
    double real, imag;
    inline ComplexNumber& operator = (const ComplexNumber &c) {real = c.real; imag = c.imag; return *this;}
    friend inline ostream& operator << (ostream &s, const ComplexNumber &c) {return s<<'<'<<c.real<<','<<c.imag<<'>';}
};
inline ComplexNumber operator + (const ComplexNumber &x, const ComplexNumber &y) {
    return {x.real + y.real, x.imag + y.imag};
}
inline ComplexNumber operator - (const ComplexNumber &x, const ComplexNumber &y) {
    return {x.real - y.real, x.imag - y.imag};
}
inline ComplexNumber operator * (const ComplexNumber &x, const ComplexNumber &y) {
    return {x.real * y.real - x.imag * y.imag, x.real * y.imag + x.imag * y.real};
}
inline ComplexNumber operator * (const ComplexNumber &x, double a) {
    return {x.real * a, x.imag * a};
}
inline ComplexNumber operator / (const ComplexNumber &x, double a) {
    return {x.real / a, x.imag / a};
}

struct FFT {
    static const int MAX = 1<<19;               // must be 2^n
    ComplexNumber AT[MAX], BT[MAX], CT[MAX];

    void DTM(ComplexNumber F[], bool inv) {
        int N = MAX;
        for (int t = N; t >= 2; t >>= 1) {
            double ang = acos(-1.0)*2/t;
            for (int i = 0; i < t/2; i++) {
                ComplexNumber w = {cos(ang*i), sin(ang*i)};
                if (inv) w.imag = -w.imag;
                for (int j = i; j < N; j += t) {
                    ComplexNumber f1 = F[j] + F[j+t/2];
                    ComplexNumber f2 = (F[j] - F[j+t/2]) * w;
                    F[j] = f1;
                    F[j+t/2] = f2;
                }
            }
        }
        for (int i = 1, j = 0; i < N; i++) {
            for (int k = N >> 1; k > (j ^= k); k >>= 1);
            if (i < j) swap(F[i], F[j]);
        }
    }
    
    // C is A*B
    void mult(long long A[], long long B[], long long C[]) {
        for (int i = 0; i < MAX; ++i) AT[i] = {(double)A[i], 0.0};
        for (int i = 0; i < MAX; ++i) BT[i] = {(double)B[i], 0.0};
        
        DTM(AT, false);
        DTM(BT, false);
        
        for (int i = 0; i < MAX; ++i) CT[i] = AT[i] * BT[i];
        
        DTM(CT, true);
        
        for (int i = 0; i < MAX; ++i) {
            CT[i] = CT[i] / MAX;
            C[i] = (long long)(CT[i].real + 0.5);
        }
    }
};

int main() {
    long long n, x;
    cin >> n >> x;
    vector<int> a(n, 0);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        if (a[i] < x) a[i] = 1;
        else a[i] = 0;
    }
    vector<int> S(n+1, 0);
    for (int i = 0; i < n; ++i) S[i+1] = S[i] + a[i];
    
    long long A[FFT::MAX] = {0}, B[FFT::MAX] = {0}, C[FFT::MAX] = {0};
    for (int i = 0; i <= n; ++i) A[S[i]]++;
    for (int i = 0; i <= n; ++i) B[i] = A[n-i];
    FFT f; f.mult(A, B, C);

    vector<long long> res(n+1, 0);
    res[0] = (n+1) * n / 2;
    for (int i = 1; i <= n; ++i) {
        res[i] = C[n-i];
        res[0] -= C[n-i];
    }
    for (int i = 0; i <= n; ++i) {
        cout << res[i];
        if (i != n) cout << " ";
    }
    cout << endl;
}