けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 127 D - Integer Cards (400 点)

混ぜてソートは賢すぎる!!!惚れた!!!

問題へのリンク

問題概要

 N 枚のカードにそれぞれ  A_1, A_2, \dots, A_N の数値が書かれている。あなたは、 j = 1, 2, \dots, M について順に以下の操作を 1 回ずつ行います。

  • カードを  B_j 枚まで選ぶ(0 枚でもよい)。選んだカードに書かれている整数をそれぞれ  C_j に書き換える。

操作終了後のカードの数値の総和の最大値を求めよ。

制約条件

  •  1 \le N, M \le 10^{5}

考えたこと 1

ややこしいけど、こういうのはちゃんと順を追って考えれば行ける。とりあえず思うのは

  • もともとが小さいやつを優先的に書き換えたい
  • なるべく大きいやつに書き換えたい

このあたりのことをきちんと整理しながら考える。具体的には「変数を固定して考える」という戦略が良さそうに思える。まず、 N 枚のカードのうち  K 枚を書き換えるとすると、

  •  A_{1}, A_{2}, \dots, A_{N} のうち小さい順に  K 枚を書き換える
  •  (C_{1}, \dots, C_{1}), (C_{2}, \dots, C_{2}), \dots, (C_{M}, \dots, C_{M}) のうち大きい順に  K 枚を書き換える

とすればよいことがわかる。まとめると

  •  A_{1}, A_{2}, \dots, A_{N} を小さい順にソート
  •  (C_{1}, \dots, C_{1}), (C_{2}, \dots, C_{2}), \dots, (C_{M}, \dots, C_{M}) を大きい順にソート
  •  i = 0, 1, ..., N-1 に対して、 \max ( A i 番目,  C i 番目) を合計

したものが答えになる。ただし、 C の個数が  N に満たない場合は、足りない部分は単純に  A の方を足していく。

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
using namespace std;
using pll = pair<long long, long long>;

int main() {
    int N, M; cin >> N >> M;
    vector<long long> A(N), B(M), C(M);
    for (int i = 0; i < N; ++i) cin >> A[i];
    sort(A.begin(), A.end());
    B.resize(M);
    C.resize(M);
    for (int i = 0; i < M; ++i) cin >> B[i] >> C[i];

    // C をソート (B をまとめて)
    vector<int> id(M);
    iota(id.begin(), id.end(), 0);
    sort(id.begin(), id.end(), [&](int i, int j) {
            return C[i] > C[j];});

    // A (小さい順) と C (大きい順) とを比べて大きい方を足していく
    long long sum = 0;
    long long K = 0;
    for (auto i : id) {
        for (int j = 0; j < B[i]; ++j) {
            if (K >= N) break;
            sum += max(A[K++], C[i]);  
        }
    }
    for (int i = K; i < N; ++i) sum += A[i];        
    cout << sum << endl;
}

解法 2

実はすごく簡単に

  •  A_{1}, A_{2}, \dots, A_{N}
  •  (C_{1}, \dots, C_{1}), (C_{2}, \dots, C_{2}), \dots, (C_{M}, \dots, C_{M})

を全部混ぜて大きい順に  N 個とった合計値で OK。このとき全部混ぜたものの個数は  O(NM) になってまともに扱えないが、ランレングス圧縮した世界で考えれば OK。

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
using namespace std;
using pll = pair<long long, long long>;

int main() {
    int N, M; cin >> N >> M;
    vector<pll> v;
    for (int i = 0; i < N; ++i) {
        int a; cin >> a;
        v.push_back({a, 1});
    }
    for (int i = 0; i < M; ++i) {
        int b, c; cin >> b >> c;
        v.push_back({c, b});
    }
    sort(v.begin(), v.end(), greater<pll>());

    long long num = 0;
    long long res = 0;
    for (auto p : v) {
        if (num + p.second <= N) {
            res += p.first * p.second;
            num += p.second;
        }
        else {
            res += p.first * (N - num);
            num = N;
            break;
        }
    }
    cout << res << endl;
}