けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 143 F - Distinct Numbers (600 点)

かなり辛いことを頑張ったけど、本当はすごく明快な問題だった!!

問題へのリンク

問題概要

 N 個の整数  A_1, A_2, \dots, A_N がある。各  K = 1, 2, \dots, N に対して、以下の答えを求めよ。

  • 残っている整数から、どの 2 つも互いに異なる  K 個の整数を選んで抜き取ることを繰り返したい
  • 抜き取れる回数の最大値を求めよ

制約

  •  1 \le N \le 3 \times 10^{5}

解法 1: 初手二分探索解法

ある種の問題は、とにかく初手で二分探索を決め打つと見通しがよくなって解けたりする。今回もまさにそういう問題といえるかもしれない。各  K に対して、


 x 回 (以上) 抜き取れるかを判定せよ


という判定問題を考えて、これを満たす最大の  x を二分探索で求めればよい。

まず整数  Aヒストグラム化する。すなわち  A = (1, 1, 1, 2, 3, 3) とかだったら、同じ整数ごとに個数をまとめると (1 個、2 個、3 個) となる。この配列を改めて  C とする。

このとき、まず言えることは  C のうち  x より多い値になっているところは、最後は余ることになる。よって

  •  E_i = \min(C_i, x)

とする。このとき、

  •  E_i の総和が  Kx 未満だったらそもそもダメ
  • 逆に  E_i の総和が  Kx 以上であれば、 K 個の整数を  x 回抜き取ることができる

ということがわかる。よって条件は、


 \sum_{i} \min(C_i, x) \ge Kx


と簡潔に表すことができる。この判定は  O(N) なので、二分探索フレームワーク全体を含めて、 O(N\log{N}) でできる。各  K に対してこれを行ってしまうと  O(N^{2} \log{N}) となるので高速化する。

高速化

上記の条件の左辺は、 K を含まないので使い回すことができる。すなわち

  •  S_{x} = \sum_{i} \min(C_i, x)

を満たすような配列  S が予め求められれば OK。やり方としては

  •  S_{x-1} から  S_{x} になるに従って増える分は、 C のうち  x 以上のものの個数である

ということに着目する。この個数を  D_{x} として、累積和をとると  S_{x} になる。

こうして全体として、 O(N\log{N}) で求められるようになった。

#include <iostream>
#include <vector>
using namespace std;

int main() {
    // C: A のカウンティング配列
    // D[x] := C の中に x 以上が何個あるか
    // S: D の累積和
    int N; cin >> N;
    vector<long long> C(N, 0);
    vector<long long> D(N+1, 0);
    for (int i = 0; i < N; ++i) {
        int a; cin >> a; --a;
        C[a]++;
        D[C[a]]++;
    }
    vector<long long> S(N+1, 0);
    for (int i = 1; i < N+1; ++i) S[i] = S[i-1] + D[i];

    for (long long K = 1; K <= N; ++K) {
        long long low = 0, high = N+1;
        while (high - low > 1) {
            long long x = (low + high) / 2;
            if (S[x] >= K * x) low = x;
            else high = x;
        }
        cout << low << endl;
    }
}

解法 2: さらに高速化して O(N) に

よく考えると実は二分探索が要らないことに気づいたりするみたい。条件

 \sum_{i} \min(C_i, x) \ge Kx

をちょっと変形して

 K \le \sum_{i} \min(C_i, x) / x

とする。各  x に対して右辺を求めておく。そうすると各  K に対するクエリは

  •  K \le \sum_{i} \min(C_i, x) / x を満たすような最大の  x を求める

とすればよい。各  K に対してそのような  x は単調減少なので、 x N から下げていくようにループを回していくことで、全体として  O(N) でできる。

#include <iostream>
#include <vector>
using namespace std;

int main() {
    // C: A のカウンティング配列
    // D[x] := C の中に x 以上が何個あるか
    // S: D の累積和
    // S2[x] = S[x] / x
    int N; cin >> N;
    vector<long long> C(N, 0);
    vector<long long> D(N+1, 0);
    for (int i = 0; i < N; ++i) {
        int a; cin >> a; --a;
        C[a]++;
        D[C[a]]++;
    }
    vector<long long> S(N+1, 0), S2(N+1, 0);
    for (int i = 1; i < N+1; ++i) S[i] = S[i-1] + D[i], S2[i] = S[i] / i;

    long long x = N;
    for (long long K = 1; K <= N; ++K) {
        while (x > 0 && K > S2[x]) --x;
        cout << x << endl;
    }
}

解法 3 (?): 僕が実際やったときの解法

超厄介で、上手く説明できない。ゴチャゴチャやったら通った...コードのみ。

観察を述べると...
上記の  C を小さい順にソートしたとき、

  • C = (1, 2, 8)

とかだと、 K = 2 だったら、8 以外で 3 個しかないので、3 組しか作れない

  • C = (1, 2, 6, 100)

とかだと、 K = 3 だったら、100 以外の (1, 2, 6) で  K = 2 の場合を解いて、これは 3 なので、3 組しか作れない。一方

  • C = (10, 11, 12, 15)

とかだと、 K = 3 だったら、15 以外で (10, 11, 12) で色々できるので、この場合は (10 + 11 + 12 + 15) / 3 = 16 組作れる。

以上の考察に基づいて、C のサイズ 1 個減らして、K も 1 個減らした場合に帰着されることになる。C を前から頑張って更新していく感じ。このままだと  O(N^{2}) かかるので、情報を圧縮して、高速化して提出したら通った。

#include <iostream>
#include <vector>
#include <deque>
#include <map>
#include <algorithm>
using namespace std;

int main() {
    int N; cin >> N;
    map<long long, long long> ma;
    for (int i = 0; i < N; ++i) {
        long long a; cin >> a, ma[a]++;
    }
    vector<long long> v;
    for (auto it : ma) v.push_back(it.second);
    sort(v.begin(), v.end());

    
    long long sum = v[0];
    long long num = 1;
    deque<long long> d;
    for (int i = 1; i < v.size(); ++i) {
        // v[i] 以下が何個か
        long long low = 0, high = N+1;
        while (high - low > 1) {
            long long mid = (low + high) / 2;
            long long val;
            if (mid > d.size()) {
                long long div = num - (mid - (int)d.size() - 1);
                if (div <= 0) val = 1LL<<50;
                else val = sum / div;
            }
            else val = d[(int)d.size() - mid];
            
            if (val >= v[i]) high = mid;
            else low = mid;
        }
        
        while (d.size() > low) d.pop_front();
        while (d.size() < low) {
            long long div = i - (int)d.size();
            d.push_front(sum / div);
        }
        sum += v[i];
        num = i+1 - low;
    }
    
    for (int i = 0; i < num; ++i) cout << sum / (i+1) << endl;
    for (int i = 0; i < d.size(); ++i) cout << d[i] << endl;
    for (int i = 0; i < N - num - (int)d.size(); ++i) cout << 0 << endl;       
}