けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 077 C - Snuke Festival (ARC 084 C) (緑色, 300 点)

lower_bound の練習に!!! あと、「3 つのものを考えるときは、真ん中を固定して考える」という考え方の典型。

問題概要

3 つの数列 (長さ  N)

  •  a_{0}, a_{1}, \dots, a_{N-1}
  •  b_{0}, b_{1}, \dots, b_{N-1}
  •  c_{0}, c_{1}, \dots, c_{N-1}

が与えられる。各数列から要素  (a_{i}, b_{j}, c_{k}) を選ぶ方法のうち、

 a_{i} \lt b_{j} \lt c_{k}

を満たすものの個数を求めよ。

制約

  •  1 \le N \le 10^{5}

考えたこと

単純にすべてのペア  (a_{i}, b_{j}, c_{k}) を考える方法では  O(N^{3}) の計算量となりますので、工夫が必要です。

まずは考えやすくなるように、各数列は小さい順にソートされているものとしましょう (ソートは  O(N \log N) でできます)。さて、3 つ組について考えるときは、真ん中を固定して考えるのが定石です。つまり、 b_{j} を固定して考えてみましょう。このとき、

  •  a_{0}, a_{1}, \dots, a_{N-1} のうち、 b_{j} 未満の個数を数える ( A_{j} とする)
  •  c_{0}, c_{1}, \dots, c_{N-1} のうち、 b_{j} 以上の個数を数える ( C_{j} とする)

というようにします (下図参照)。 a_{i} \lt b_{j} \lt c_{k} を満たす組の個数は  A_{j} \times C_{j} となります。これを  j = 0, 1, \dots, N-1 について足すことで答えが求められます。

f:id:drken1215:20210225221741p:plain

 A_{j} C_{j} を求める

これらはいずれも二分探索 (特に lower_bound) を使って求められます!

まず  a_{0}, a_{1}, \dots, a_{N-1} のうち、 b_{j} 未満の個数を数えてみましょう。lower_bound を使うと、次のようにして  a_{k} \ge b_{j} を満たす最小の  k が求められます。

int k = lower_bound(a.begin(), a.end(), b[j]) - a.begin();

そしてこのとき、 a_{i} \lt b_{j} を満たす  i の個数は  k になります。こうして  A_{j} = k であることがわかりました。

 C_{j} についても同様に求められます。具体的には、 c_{i} \le b_{j} (不等号に等号が付くことに注意) を満たす  i の個数を求めて、それを  N から引けばよいでしょう。そのような  i の個数は、lower_bound() の代わりに upper_bound() を用いることで求められます。

計算量・コード

 b_{j} に対して、 A_{j} C_{j} O(\log N) で求められます。よって全体の計算量は  O(N \log N) となることがわかりました。

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

int main() {
    // 入力
    int N;
    cin >> N;
    vector<long long> a(N), b(N), c(N);
    for (int i = 0; i < N; ++i) cin >> a[i];
    for (int i = 0; i < N; ++i) cin >> b[i];
    for (int i = 0; i < N; ++i) cin >> c[i];

    // ソートする
    sort(a.begin(), a.end());
    sort(b.begin(), b.end());
    sort(c.begin(), c.end());

    // b[j] を固定して考える
    long long res = 0;
    for (int j = 0; j < N; ++j) {
        long long Aj = lower_bound(a.begin(), a.end(), b[j]) - a.begin();
        long long Cj = N - (upper_bound(c.begin(), c.end(), b[j]) - c.begin());
        res += Aj * Cj;
    }
    cout << res << endl;
}