けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

yukicoder No.2423 Merge Stones

ビットベクターで高速化するのは、いつも見落としてしまう!

問題概要

円環状に  N 個の魔法石があり、 i 番目の魔法石は、価値が  A_{i}、色が  C_{i} である。色は 1 以上 50 以下の整数で表される。

今、これらの魔法石に対して、隣り合う魔法石の色の差が  K 以下であるとき、以下の操作を実行することができる。

  • 2 つの魔法石を合体させて 1 つの魔法石にする
  • 合体された魔法石の価値は、合体させた 2 つの魔法石の価値の和となる
  • 合体された魔法石の色は、合体させた 2 つの魔法石の色のいずれかの色にすることができる

操作の過程で発生する魔法石の価値として考えられる最大値を求めよ。

制約

  •  1 \le N \le 300
  •  0 \le K \le 5
  •  1 \le A_{i} \le 10^{9}
  •  1 \le C_{i} \le 50

考えたこと

円環状であることに関しては、数列を 2 週させた上で「最大で  N 個の連続した魔法石を合体できる」という制約にすれば OK。以降、一直線状の問題として考え直す。

いかにも区間 DP な設定なので、区間 DP を考えてみる。


dp[l][r][c] ← 区間  \lbrack l, r) の魔法石をすべて合体することで、色  c の 1 つの魔法石を作ることが可能かどうか (True / False)


例によって、こんな感じの DP ができる。

for (int between = 2; between <= N; ++between) {
    for (int l = 0; l < N*2 && l + between <= N*2; ++l) {
        int r = l + between;

        // 区間 [l, r) を [l, m) と [m, r) に分割する
        for (int m = l+1; m < r; ++m) {
           for (int c1 = 1; c1 <= 50; ++c1) {
               for (int c2 = max(0, c1-K); c2 <= min(50, c1+K); ++c2) {
                   if (dp[l][m][c1] && dp[m][r][c2]) {
                      dp[l][m][c1] = true;
                      dp[l][m][c2] = true;
                   }
                }
            }
        }
    }
}

しかしこのままでは、色の種類数を  M として、全体で  O(N^{3}KM) の計算量を要して間に合わない。

ビットベクター高速化

僕は高速化できずに悩んでいたけど、ビットベクター高速化はいつも忘れてしまう。

dp[l][r] を要素数  M(= 50)bitset であるとして、dp の更新を bitset の演算でまとめて処理すれば、実質的な計算量は  O(N^{3}K) となる。

具体的には、次のコードのように実現できる。

コード

#include <bits/stdc++.h>
using namespace std;

int main() {
    // 入力 (数珠なので、2 週させる)
    int N, K;
    cin >> N >> K;
    vector<long long> A(N*2), C(N*2);
    for (int i = 0; i < N; ++i) { cin >> A[i]; A[i+N] = A[i]; }
    for (int i = 0; i < N; ++i) { cin >> C[i]; C[i+N] = C[i]; }
    
    // 累積和
    vector<long long> S(N*2+1, 0);
    for (int i = 0; i < N*2; ++i) S[i+1] = S[i] + A[i];
    
    // dp
    vector dp(N*2, vector(N*2+1, bitset<52>(false)));
    for (int l = 0; l < N*2; ++l) dp[l][l+1][C[l]] = true;
    for (int between = 2; between <= N; ++between) {
        for (int l = 0; l + between <= N*2; ++l) {
            int r = l + between;
            for (int m = l+1; m < r; ++m) {
                for (int k = 0; k <= K; ++k) {
                    dp[l][r] |= dp[l][m] & (dp[m][r] << k);
                    dp[l][r] |= dp[l][m] & (dp[m][r] >> k);
                    dp[l][r] |= (dp[l][m] << k) & dp[m][r];
                    dp[l][r] |= (dp[l][m] >> k) & dp[m][r];
                }
            }
        }
    }
    
    // 集計
    long long res = 0;
    for (int l = 0; l <= N; ++l) {
        for (int r = l+1; r <= N*2; ++r) {
            for (int c = 0; c <= 50; ++c) {
                if (dp[l][r][c]) res = max(res, S[r] - S[l]);
            }
        }
    }
    cout << res << endl;
}