けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AOJ 0660 たのしいたのしいたのしい家庭菜園 (JOI 2019 本選 C)

結構苦手系。想定解法かはわからないけどやってみた

問題へのリンク

問題概要

'R', 'G', 'Y' の 3 種類の文字で構成された長さ  N の文字列  S が与えられる。これに以下の操作を行って「隣り合う 2 文字が同じになることはない」ようにしたい。それが可能となる操作回数の最小値を求めよ (不可能ならば -1)。

  • 隣り合う 2 文字を swap する

制約

  •  1 \le N \le 400

考えたこと

まず「隣り合う 2 文字が同じにならないように」というのが可能かどうかの判定は容易に行える。

一番登場回数の多い文字の出現回数を  a として、 a+1 箇所の隙間に残りの文字をすべて入れられれば OK

なんとなく唐突に ARC 097 E - Sorted and Sorted を思い出した。これと同じような DP でできそうな気持ちになった。なんだろ、問題の目的は違うけど、隣接 swap は DP というお気持ち。

  • dp[ r ][ g ][ y ][ 3 ] := 左から r + g + y 個について、R が r 個、G が g 個、Y が y 個になるようにするための必要操作回数の最小値 (最後の引数は、最後の文字が R, G, Y のいずれかを表す)

とすればよさそう。遷移を考えるのがちょっと面倒。

まず、左から r + g + y 個についての R, G, Y の個数の内訳が決まった時点で、残りの文字列がどうなっているかは Greedy に一意に決まることに注意する。予め、

  • num[ i ][ c1 ][ c2 ] := 初期状態において、色が c1 な文字のうち i 番目のものの右側に色が c2 な文字が何個あったか

というのを持っておくとよさそう。そうすると、左から r + g + y 個の並びを決めたときに、例えば r+1 番目の R の右側にある R, G, Y の個数はそれぞれ

  • |R| - (r + 1) (R の個数を |R| とした)
  • min( num[ r + 1 ][ R ][ G ], |G| - g )
  • min( num[ r + 1 ][ R ][ Y ], |Y| - y )

となるので、これを合計したものを sum として N - 1 - sum 番目にあることがわかる。よって (r, g, y) から (r + 1, g, y) へと遷移するコストは

  • N - 1 - sum
  • r + g + y

との差で表されることとなる。

#include <iostream>
#include <string>
#include <cmath>
using namespace std;
void chmin(int &a, int b) { if (a > b) a = b; }
const int INF = 1LL<<29;
const int MAX = 202;

// 入力
int N;
string S;

// 前処理結果
int nG, nR, nY;
int num[MAX][3][3];

int calc_diff(int r, int g, int y, int color) {
    int sum = 0;
    if (color == 0) {
        sum = nR - (r+1) + min(num[r][0][1], nG-g) + min(num[r][0][2], nY-y);
    }
    else if (color == 1) {
        sum = nG - (g+1) + min(num[g][1][0], nR-r) + min(num[g][1][2], nY-y);
    }
    else {
        sum = nY - (y+1) + min(num[y][2][0], nR-r) + min(num[y][2][1], nG-g);
    }
    int res = abs((N-1-sum) - (r+g+y));
    return res;
}

int dp[MAX][MAX][MAX][3];
int solve() {
    // -1 の場合
    N = (int)S.size();
    nR = 0, nG = 0, nY = 0;
    for (int i = 0; i < N; ++i) {
        if (S[i] == 'R') ++nR;
        else if (S[i] == 'G') ++nG;
        else ++nY;
    }
    if (nR > (N+1)/2 || nG > (N+1)/2 || nY > (N+1)/2) return -1;
    
    // 前処理
    int num_r = 0, num_g = 0, num_y = 0;
    for (int i = N-1; i >= 0; --i) {
        if (S[i] == 'R') {
            int r_ind = nR - 1 - num_r;
            num[r_ind][0][0] = num_r;
            num[r_ind][0][1] = num_g;
            num[r_ind][0][2] = num_y;
            ++num_r;
        }
        else if (S[i] == 'G') {
            int g_ind = nG - 1 - num_g;
            num[g_ind][1][0] = num_r;
            num[g_ind][1][1] = num_g;
            num[g_ind][1][2] = num_y;
            ++num_g;
        }
        else {
            int y_ind = nY - 1 - num_y;
            num[y_ind][2][0] = num_r;
            num[y_ind][2][1] = num_g;
            num[y_ind][2][2] = num_y;
            ++num_y;
        }
    }
    
    // DP 初期化
    for (int i = 0; i < MAX; ++i)
        for (int j = 0; j < MAX; ++j)
            for (int k = 0; k < MAX; ++k)
                for (int l = 0; l < 3; ++l)
                    dp[i][j][k][l] = INF;
        
    // DP 初期条件
    dp[1][0][0][0] = calc_diff(0, 0, 0, 0);
    dp[0][1][0][1] = calc_diff(0, 0, 0, 1);
    dp[0][0][1][2] = calc_diff(0, 0, 0, 2);
        
    // DP
    for (int r = 0; r <= nR; ++r) {
        for (int g = 0; g <= nG; ++g) {
            for (int y = 0; y <= nY; ++y) {
                if (r < nR) {
                    chmin(dp[r+1][g][y][0], dp[r][g][y][1] + calc_diff(r, g, y, 0));
                    chmin(dp[r+1][g][y][0], dp[r][g][y][2] + calc_diff(r, g, y, 0));
                }
                if (g < nG) {
                    chmin(dp[r][g+1][y][1], dp[r][g][y][0] + calc_diff(r, g, y, 1));
                    chmin(dp[r][g+1][y][1], dp[r][g][y][2] + calc_diff(r, g, y, 1));
                }
                if (y < nY) {
                    chmin(dp[r][g][y+1][2], dp[r][g][y][0] + calc_diff(r, g, y, 2));
                    chmin(dp[r][g][y+1][2], dp[r][g][y][1] + calc_diff(r, g, y, 2));
                }
            }
        }
    }
    
    // 答え
    int res = INF;
    for (int i = 0; i < 3; ++i) chmin(res, dp[nR][nG][nY][i]);
    if (res < INF) return res;
    else return -1;
}
        
int main() {
    cin >> N >> S;
    cout << solve() << endl;
}