けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 075 E - Meaningful Mean (青色, 600 点)

じょえちゃんえるから。
「平均値が  K 以上」という条件を見たときにパッと考えつく話がある。

問題へのリンク

問題概要

 N 個の正の整数列  A_{1}, \dots, A_{N} と整数  K が与えられる。

整数列の連続する部分列であって、その平均値が  K 以上であるものが何個あるかを求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

一般に、

 a, b, c, \dots の平均値が  K 以上
 a-K, b-K, c-K, \dots の平均値が  0 以上
 a-K, b-K, c-K, \dots の総和が  0 以上

という言い換えができる。「平均値」に関する条件が、なんと「総和」に関する条件へと早変わりして、滅茶苦茶扱いやすくなるのだ。というわけで、元の問題は次の問題へと言い換えることができる。


 N 個の整数  A_{1} - K, \dots, A_{N} - K の連続する部分列であって、その総和が  0 以上であるものが何個あるかを求めよ。


累積和へ

次に「連続する部分列の総和」と言われたら、累積和が使える。 A_{1} - K, \dots, A_{N} - K の累積和を  S_{0}, S_{1}, \dots, S_{N} とする。そうすると、元の問題はさらに次のように言い換えられる。


 N+1 個の整数  S_{0}, S_{1}, \dots, S_{N} に対して、

  •  S_{j} - S_{i} \ge 0 (⇔  S_{i} \le S_{j})

が成立するような  0 \le i \lt j \le N の組が何個あるかを求めよ。


これはもう、ほとんど転倒数と一緒だ。

転倒数へ

転倒数は

  •  S_{i} \ge S_{j}

を満たすような  0 \le i \lt j \le N の組の個数である。今回は不等号は逆だけど、求め方はほとんど一緒。BIT を使うといい感じにできる!

転倒数の求め方は、

などに詳しく書いてある!!計算量は  O(N\log{N})

コード

#include <bits/stdc++.h>
using namespace std;

template <class Abel> struct BIT {
    const Abel UNITY_SUM = 0;                       // to be set
    vector<Abel> dat;
    
    /* [1, n] */
    BIT(int n) : dat(n + 1, UNITY_SUM) { }
    void init(int n) { dat.assign(n + 1, UNITY_SUM); }
    
    /* a is 1-indexed */
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i += i & -i)
            dat[i] = dat[i] + x;
    }
    
    /* [1, a], a is 1-indexed */
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a; i > 0; i -= i & -i)
            res = res + dat[i];
        return res;
    }
    
    /* [a, b), a and b are 1-indexed */
    inline Abel sum(int a, int b) {
        return sum(b - 1) - sum(a - 1);
    }
    
    /* debug */
    void print() {
        for (int i = 1; i < (int)dat.size(); ++i) cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};

int main() {
    int N;
    long long K;
    cin >> N >> K;
    vector<long long> S(N+1, 0);
    for (int i = 0; i < N; ++i) {
        long long a;
        cin >> a;
        S[i+1] = S[i] + (a - K);
    }
    vector<long long> SS = S;
    sort(SS.begin(), SS.end());
    SS.erase(unique(SS.begin(), SS.end()), SS.end());
    long long res = 0;
    BIT<long long> bit(N+10);
    for (int i = 0; i <= N; ++i) {
        int id = lower_bound(SS.begin(), SS.end(), S[i]) - SS.begin();
        res += bit.sum(id+1);
        bit.add(id+1, 1);
    }
    cout << res << endl;
}