けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 155 D - Pairs (400 点)

最初は、問題を見た瞬間に「にぶたんだ...」となったので、青 diff に驚いたのだった。でもいざ実装を始めると、頭壊れる問題ですね。。。^^;

問題へのリンク

問題概要

 N 個の整数  A_{1}, \dots, A_{N} が与えられる。これらから 2 つを選んで積をとって得られる  \frac{N(N-1)}{2} 個の整数のうち、小さい順に  K 番目の値を求めよ。

制約

  •  2 \le N \le 2 \times 10^{5}
  •  -10^{9} \le A_{i} \le 10^{9}

考えたこと

小さい順に  K 番目の値を求めよ」と言われたときに、頻出の一つの考え方がある。それは、問題を次のように言い換えることだ。


 x 以下のものが  K 個以上あるような、最小の  x を求めよ


この言い換えを行うと、二分探索で求められるということがわかる!!!つまり、次の判定問題を、 \log 回解けば良いのだ。


 x 以下のものが  K 個以上あるかどうかを判定せよ。


判定パート

 A_{1}, \dots, A_{N} から 2 つ選んで積をとった値のうち、 x 以下のものが何個あるかを求める問題となった。依然として  O(N^{2}) から落とすことが課題となるけど、かなり考えやすくなった。

こういうシチュエーションを扱う定石として、「1 つを固定したときに、条件を満たす相方が何個あるのかを求める」というのがある。注意点として、

  •  A_{i} の相方として  A_{j} を選ぶ
  •  A_{j} の相方として  A_{i} を選ぶ

というのは重複してしまうので、最後に 2 で割ることにする。あと、

  •  A_{i} の相方として  A_{i} 自身を選ぶ

というのもまとめて数えてしまっていい。最後にそれも除く。まとめると、以下の手順で求めることができる。


  1.  A_{i} について、 A_{1}, \dots, A_{N} のうち、掛けて  x 以上になるやつが何個あるかを求める (その合計値を  S とする)
  2.  A_{i} \times A_{i} \le x となるようなものが何個あるかを求める  T とする
  3. このとき、求める個数は  \frac{S - T}{2} となる

2 は  O(N) でできる。1 について考える。    

Ai > 0 のとき

このとき、 A_{i} \times y という値は「 y について単調増加」であることに注意しよう。

よって、 A_{1}, A_{2}, \dots, A_{N} をソートしておけば、

  •  A_{i} \times A_{j} \le x を満たすような  j の最大値

を、再び二分探索で求めることができる。なお、この部分を std::lower_bound() を使う手もありそう...だけど、 \frac{x}{A_{i}} を計算する必要があって、C++ だと、この値は  x の正負によって挙動が変わるので場合わけが面倒だと思う。。。

Ai < 0 のとき

今度は反対に、 A_{i} \times y という値は「 y について単調減少」であることに注意しよう。

よって、 A_{1}, A_{2}, \dots, A_{N} をソートしておけば、

  •  A_{i} \times A_{j} \le x を満たすような  j の最小値

を、再び二分探索で求めることができる。

まとめ

まとめると、

  1. 大元の問題を二分探索に落とす
  2. 判定パートを、A[ i ] を固定しながら処理する

計算量は、判定問題を解くのに  O(N\log{N}) の計算量を要するので、全体で  O(N\log{N}\log{B}) となる ( B は登場する値の最大値)。

なお別解として、2. の部分の処理にしゃくとり法を用いる方法もありそう。

#include <bits/stdc++.h>
using namespace std;

const long long INF = 1LL<<61;
long long N, K;
vector<long long> A;

long long solve() {
    sort(A.begin(), A.end());
    long long left = -INF, right = INF;
    while (right - left > 1) {
        long long x = (left + right) / 2;
        long long S = 0, T = 0;
        for (int i = 0; i < N; ++i) {
            if (A[i] > 0) {
                int l2 = -1, r2 = N;
                while (r2 - l2 > 1) {
                    int m = (l2 + r2) / 2;
                    if (A[i] * A[m] <= x) l2 = m;
                    else r2 = m;
                }
                S += r2;
            }
            else if (A[i] < 0) {
                int l2 = -1, r2 = N;
                while (r2 - l2 > 1) {
                    int m = (l2 + r2) / 2;
                    if (A[i] * A[m] <= x) r2 = m;
                    else l2 = m;
                }
                S += N - r2;
            }
            else if (x >= 0) S += N;
            if (A[i] * A[i] <= x) ++T;
        }
        long long num = (S - T) / 2;
        if (num >= K) right = x;
        else left = x;
    }
    return right;
}

int main() {
    cin >> N >> K;
    A.resize(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    cout << solve() << endl;
}