けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 3212 Intimate Slimes (OUPC 2020 D)

|x-a| + |x-b| + ... + |x-z| を最小にする x が a, b, ..., z のメディアンになる話は有名で、それを拡張すると仕組みがわかった!

問題概要

 N 体のスライムがいて、それぞれの強さは  X_{1}, \dots, X_{N} となっている。以下の操作を行うことができる

  • あるスライムの強さを 1 増加させる (コストは  A)
  • あるスライムの強さを 1 減少させる (コストは  B)

 K = 1, 2, \dots, N に対して、以下の問に答えよ。

「スライム  1, 2, ..., K の強さを等しくするのに必要なコストの最小値を求めよ」

制約

  •  1 \le N \le 10^{5}
  •  1 \le X_{i} \le 10^{9}

考えたこと

もし  A = B であったならば、

 |x - X_{1}| + |x - X_{2}| + \dots + |x - X_{K}]

の最小値を求めよ、という問題と等価になっていた。これは  x がメディアンになることは有名だったりする。以下の記事でも議論している。

drken1215.hatenablog.com

これを参考にすると、一般の  A, B に対しても次のことが言える。ここで、一致させるスライムの強さを  x とする

  •  x X_{1}, \dots, X_{K} のいずれかに一致する場合のみ考えればよい (そうでない場合は左右どちらかに動かすことで「良くできる」 or 「変わらない」と言えるため)
  • さらに具体的には  (i-1) : (K-i) B : A に近い状況となる  i に対して  x = X_{i} が最適となる ( x を左右どちらに動かしても解が良くならないような、バランスのとれた位置となるため)

微妙なズレが怖かったので、実装上は  i = KA/(A+B) まわりを ±10 くらい試した。しっかりした答えはふるやんさんの記事より。

www.creativ.xyz

なお、三分探索による別解もあった模様。

高速化

残る問題は、

  1.  X_{1}, \dots, X_{K} のうち小さい順に  i 番目の整数  x を求める
  2.  X_{j} \lt x なる  j について、 x - X_{j} の総和を求める
  3.  X_{j} \gt x なる  j について、 X_{j} - x の総和を求める

という処理を高速にこなすこと。ここでは BIT 上二分探索を用いることにした。2, 3 についても BIT で管理することで高速に求められる。全部まとめて、計算量は  O(N \log N) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// get(k): binary search on BIT
template <class Abel> struct BIT {
    Abel UNITY_SUM = 0;
    vector<Abel> dat;
    
    // [0, n)
    BIT(int n, Abel unity = 0) : UNITY_SUM(unity), dat(n, unity) { }
    void init(int n) {
        dat.assign(n, UNITY_SUM);
    }
    
    // a is 0-indexed
    inline void add(int a, Abel x) {
        for (int i = a; i < (int)dat.size(); i |= i + 1)
            dat[i] = dat[i] + x;
    }
    
    // [0, a), a is 0-indexed
    inline Abel sum(int a) {
        Abel res = UNITY_SUM;
        for (int i = a - 1; i >= 0; i = (i & (i + 1)) - 1)
            res = res + dat[i];
        return res;
    }
    
    // [a, b), a and b are 0-indexed
    inline Abel sum(int a, int b) {
        return sum(b) - sum(a);
    }

    // k-th number (k is 0-indexed)
    int get(long long k) {
        ++k;
        int res = 0;
        int N = 1;
        while (N < (int)dat.size()) N *= 2;
        for (int i = N / 2; i > 0; i /= 2) {
            if (res + i - 1 < (int)dat.size() && dat[res + i - 1] < k) {
                k = k - dat[res + i - 1];
                res = res + i;
            }
        }
        return res;
    }
    
    // debug
    void print() {
        for (int i = 0; i < (int)dat.size(); ++i)
            cout << sum(i, i + 1) << ",";
        cout << endl;
    }
};



int main() {
    long long N, A, B;
    cin >> N >> A >> B;
    vector<long long> V(N);
    for (int i = 0; i < N; ++i) cin >> V[i];

    // 座標圧縮
    auto alt = V;
    sort(alt.begin(), alt.end());
    alt.erase(unique(alt.begin(), alt.end()), alt.end());    
    int M = alt.size();
    BIT<long long> bit(M+10); // 個数管理用
    BIT<long long> bit2(M+10); // 総和を求める用

    for (long long k = 0; k < N; ++k) {
        int i = lower_bound(alt.begin(), alt.end(), V[k]) - alt.begin();
        bit.add(i, 1);
        bit2.add(i, V[k]);

        long long res = 1LL<<60;
        long long base = k * B / (A+B);
        for (long long left = max(0LL, base-10); left <= min(k, base+10); ++left) {
            long long val = alt[bit.get(left)];
            int i = lower_bound(alt.begin(), alt.end(), val) - alt.begin();
            long long leftsum = val * bit.sum(0, i) - bit2.sum(0, i);
            long long rightsum = bit2.sum(i+1, M+5) - val * bit.sum(i+1, M+5);
            res = min(res, leftsum*A + rightsum*B);
        }
        cout << res << endl;
    }
}