けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 186 D - Sum of difference (茶色, 400 点)

いろんな方法が考えられそう!

問題概要

 N 個の整数  A_{1}, \dots, A_{N} が与えられる。

 1 \le i \lt j \le N を満たすすべての  (i, j) の組に対する  |A_{i} - A_{j}| の総和を求めよ。

制約

  •  2 \le N \le 2 \times 10^{5}
  •  -10^{8} \le A_{i} \le 10^{8}

考えたこと

絶対値のままだと厄介。ちょっと工夫する。まず、数列  A の並びを入れ替えたとしても答えが変わらないことに注意しよう!!!というわけで、 A を小さい順にソートしてしまうことにする。そうすると、

  •  1 \le i \lt j \le N を満たす  (i, j) についての、 A_{j} - A_{i} の総和

を求める問題になる。絶対値が外れたので少し考えやすくなった。あと、ここから先は 0-indexed で考えることにする。

解法 (1): j を固定

 i, j が両方動くのでは考えづらいので、今回は  j を固定して考えてみることにする。たとえば  j = 4 のときを書き出してみると

  •  A_{4} - A_{0}
  •  A_{4} - A_{1}
  •  A_{4} - A_{2}
  •  A_{4} - A_{3}

の総和を求める問題となる。これを足してみると

 4 A_{4} - (A_{0} + A_{1} + A_{2} + A_{3})

となる。このうち  (A_{0} + A_{1} + A_{2} + A_{3}) の部分は、累積和の形になっている。つまり  A の累積和を  S とすると、

 4 A_{4} - (A_{0} + A_{1} + A_{2} + A_{3}) = 4 A_{4} - S_{4}

と簡潔に表せる。今回は  j = 4 のときを考えたが、一般の  j においては

 j \times A_{j} - S_{j}

となる。これを  j = 0, 1, \dots, N-1 について総和をとれば OK。計算量は  O(N) となる。

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    cin >> N;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // ソートしておく
    sort(A.begin(), A.end());

    // 累積和
    vector<long long> S(N+1, 0);
    for (int i = 0; i < N; ++i) S[i+1] = S[i] + A[i];

    // 答えを求める
    long long res = 0;
    for (int i = 0; i < N; ++i) res += A[i] * i - S[i];
    cout << res << endl;
}

解法 (2):各要素の寄与度を考える

もう一つの典型解法として、 A_{0}, \dots, A_{N-1} がそれぞれ何回ずつ足したり引いたりするのかを数えるというのがある。

たとえば  A_{0} は、 A_{*} - A_{0} の形で合算される回数が  N-1 回になる。 A_{1} は、 A_{*} - A_{1} の形で合算される回数が  N-2 回であり、 A_{1} - A_{*} の形で合算される回数が  1 回になる。

一般に  A_{i} については

  •  A_{*} - A_{i} の形で合算される回数は、 N-i-1
  •  A_{i} - A_{*} の形で合算される回数は、 i

となるので、全部で  i - (N-i-1) = 2i - N + 1 回となる。よって

 \sum_{i} (2i - N + 1) \times A_{i}

を求めれば OK。計算量は  O(N) となる。

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N;
    cin >> N;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // ソートしておく
    sort(A.begin(), A.end());

    // 答えを求める
    long long res = 0;
    for (int i = 0; i < N; ++i) {
        res += A[i] * (2*i - N + 1);
    }
    cout << res << endl;
}