けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 062 D - 3N Numbers (ARC 074 D) (青色, 500 点)

昨日の ABC で、「左右両端からの累積和」を使うと良い問題が出たので、その発展的類題の紹介に。

drken1215.hatenablog.com

問題へのリンク

問題概要

 N を正の整数とする。
 3N 個の要素からなる数列  a_{1}, \dots, a_{3N} にとおいて、 N 個の要素を取り除き、残った  2N 個の要素のうち (前半の  N 個の総和) - (後半の  N 個の総和) として考えられる最大値を求めよ。

制約

  •  1 \le N \le 10^{5}

考えたこと

少し考えると、 3N 個の要素を

  • 左から left 個のところで分割して
  • (左 left 個のうち大きい順に  N 個の和) - (右  N - left 個のうち小さい順に  N 個の和)

を求める作業を left =  N, N+1, \dots, 2N について行ってその最大値を求めればよいことがわかる。

左右両端からの前処理

というわけなので、

  • S[ i ] := 左 i 個分の大きい順に  N 個の総和
  • T[ i ] := 右 i 個分の小さい順に  N 個の総和

をそれぞれ前処理で計算できれば、この問題は解けることになる。このように左右両方から累積情報を前処理しておく方法は典型ではあって、応用範囲がとても広い。

S の求め方がわかれば T も同じなので S を求める。これは priority_queue を使うといい感じにできる。つまり

  • まず最初の  N 個の要素を突っ込んで初期化
  • 毎回、新たな要素 v が加わったときに、キューに入っている最小の値を mi として、
    • v > mi なら mi を破棄して v を新たに加え
    • v <= mi なら v を破棄してそのまま

とすることで上手く行く。T についてもほぼ同様。今度はキューには入っている最大の値と v とを比較することになる。

#include <iostream>
#include <vector>
#include <queue>
using namespace std;

int main() {
    int N; cin >> N;
    vector<long long> a(N*3);
    for (int i = 0; i < N*3; ++i) cin >> a[i];

    // 左から前処理
    vector<long long> S(N*2+1, 0);
    priority_queue<long long, vector<long long>, greater<long long> > que1;
    for (int i = 0; i < N; ++i) {
        S[i+1] = S[i] + a[i];
        que1.push(a[i]);
    }
    for (int i = N; i < N*2; ++i) {
        long long mi = que1.top();
        if (a[i] > mi) {
            S[i+1] = S[i] - mi + a[i];
            que1.pop();
            que1.push(a[i]);
        }
        else S[i+1] = S[i];
    }

    // 右から前処理
    vector<long long> T(N*2+ 1, 0);
    priority_queue<long long> que2;
    for (int i = 0; i < N; ++i) {
        T[i+1] = T[i] + a[N*3-1 - i];
        que2.push(a[N*3-1 - i]);
    }
    for (int i = N; i < N*2; ++i) {
        long long ma = que2.top();
        if (a[N*3-1 - i] < ma) {
            T[i+1] = T[i] - ma + a[N*3-1 - i];
            que2.pop();
            que2.push(a[N*3-1 - i]);
        }
        else T[i+1] = T[i];
    }

    // 集計
    long long res = -(1LL<<60);
    for (int i = N; i <= N*2; ++i) {
        res = max(res, S[i] - T[N*3-i]);
    }
    cout << res << endl;
}