けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

CS Academy 079 DIV2 E - Smallest Subsets

CS Academy 079 DIV2 E Smallest Subsets


問題概要

N 個の整数 S[0], S[1], ..., S[N-1] が与えられる。これらの部分和 (2N 通り) を小さい順に K 個出力せよ。

  • 1 <= N <= 105
  • 1 <= K <= min(2N, 105)
  • -109 <= S[i] <= 109

解法

S[i] として負の値もあるのがイヤであるが、以下のように鮮やかに解決できる:

(-9, -7, -5, 2, 8) の部分和として負の値の総和 (-21) を基準として

  • (-9) + (-7) + (-5) + 2 は、(-21) + 2 とみなす
  • (-7) + 2 + 8 は、(-21) + |-9| + |-5| + 2 + 8 とみなす
  • 2 + 8 は、(-21) + |-9| + |-7| + |-5| + 2 + 8 とみなす

つまり、負の数は絶対値をとり、「選ばない」やつは選んだことにして、「選んだ」やつは選ばなかったことにします。こうすることで、すべて非負の問題に帰着することができました。

しかしそれでも難しそうな問題に思えました。部分和というと DP したくなるのですが、|S[i]| <= 109 とあって無理です。S[i] が小さければ部分和 DP すれば、部分和が w となるような組の個数を求めることができるので、この問題も解けそうです。

さて、そのような方針が取れないので、2N 通りのうち前半の K 通りになりそうなところを効率良く探索する必要があります。例えば直感として、2k > K になるならば、k 個の部分和がランクインすることはありえないです。

以下のように賢く探索できるようです (プロのコードを参考にしました)。とりあえず S を小さい順にソートしておきます。最初に 0 (なにもない) は出力しておいて、priority_queue に S[0] (0番目のみ) を加えておいて、

vector<long long> res;
res.push_back(0);  // とりあえず何もない
priority_queue<pair<long long, int>,
                   vector<pair<long long, int> >,
                   greater<pair<long long, int> > > que;
que.push(make_pair(S[0], 0));
while (!que.empty()) {
  if (res.size() == K) break;  // K 個に達するまで
  pair<long long, int> p = que.top();
  que.pop();
  res.push_back(p.first);
  if (p.second < N-1) {  // これが決定的に効くみたい
    que.push(make_pair(p.first + S[p.second+1], p.second + 1));
    que.push(make_pair(p.first + S[p.second+1] - S[p.second], p.second + 1));
  }
}

とすればよいようです。気持ちとしては、S の各部分和として、一番大きな index で分類して管理するようです。いろんなケースを考えてみると、確かに効率良く枝刈りしまくっている感はります。S の前半をとりまくってるようなやつはなかなか que から pop されず、前半をあまり取らずに index をどんどん先に進めてるようなやつが優先的に pop されていくので、そういうやつがどんどん index = N に達して探索のリーフにたどり着けるようになっています。

なお、この探索によって全部分和を行き渡ること自体は、数学的帰納法によって示せます。この探索で計算量解析をきちんとやるとどうなるのか気になっています。

最終的な提出コードでは、冒頭の「負の数の処理」も含めています。

#include <iostream>
#include <vector>
#include <algorithm>
#include <queue>
using namespace std;

int N, K;
vector<long long> S;

int main() {
  while (cin >> N >> K) {
    S.resize(N);
    for (int i = 0; i < N; ++i) cin >> S[i];

    long long minussum = 0;
    for (int i = 0; i < N; ++i) {
      if (S[i] < 0) {
        minussum += S[i];
        S[i] = -S[i];
      }
    }
    sort(S.begin(), S.end());

    if (K == 1) {
      cout << minussum << endl;
      return 0;
    }
    
    vector<long long> res;
    priority_queue<pair<long long, int>,
                   vector<pair<long long, int> >,
                   greater<pair<long long, int> > > que;
    res.push_back(minussum);
    que.push(make_pair(minussum + S[0], 0));
    while (!que.empty()) {
      if (res.size() == K) break;
      pair<long long, int> p = que.top();
      que.pop();
      res.push_back(p.first);
      if (p.second < N-1) {
        que.push(make_pair(p.first + S[p.second+1], p.second + 1));
        que.push(make_pair(p.first + S[p.second+1] - S[p.second], p.second + 1));
      }
    }

    for (int i = 0; i < res.size(); ++i) {
      cout << res[i] << endl;
    }     
  }
}