けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder AGC 033 D - Complexity (赤色, 1000 点)

これを解けなかったのが強い敗北感。
DP 配列が巨大になりそうなときに、最適化する対象を入れ替えるテクは今までなんども見ているのにそれが思いつかない思考の硬さを思い知らされた。

問題へのリンク

問題概要

0 と 1 のみからなる行列の複雑度を

  • すべて同じ要素からなる行列の複雑度は  1
  • 一般の行列の複雑度を、縦方向または横方向に分割して 2 つの行列に分ける方法をすべて考えたときの「2 つの行列の複雑度の最大値」の最小値に 1 を足したもの

として定義する。与えられた  H ×  W の 0-1 行列の複雑度を求めよ。ただしメモリ制限は 512MB とする。

制約

  •  1 \le H, W \le 185

考えたこと

一目見て区間 DP したい気持ちになる。すなわち

  • dp[ h1 ][ h2 ][ w1 ][ w2 ] := 長方形領域の [h1, h2) × [w1, w2) の複雑度

とする。しかしこれはそもそもメモリに乗らない上に  O(H^{3}W^{3}) の計算時間がかかってしまう。ここで注目すべきことは、

  • 複雑度は最悪でも、行列サイズの対数オーダーにしかならない

ということだ。そこでこれを利用して DP の最適化したい対象を入れ替えるというテクニックが炸裂できる。ナップサック問題で「ナップサック容量がでかいが、各品物の価値は小さい」という場合に使えるテクとして有名。

qiita.com

で、このテクを使うと、

  • dp1[ h1 ][ h2 ][ w1 ][ k ] := 区間 [h1, h2) × [w1, w2) の複雑度が k 以下となる最大の w2
  • dp2[ h1 ][ w1 ][ w2 ][ k ] := 区間 [h1, h2) × [w1, w2) の複雑度が k 以下となる最大の h2

とすると良さそう。k は log オーダーなので、これでメモリにもおさまる。これを用いて dp の更新は、以下のようにできる。

縦に切れ目を入れる場合

  • w = dp1[ h1 ][ h2 ][ w1 ][ k ] として、chmax(dp1[ h1 ][ h2 ][ w1 ][ k + 1 ], dp1[ h1 ][ h2 ][ w ][ k ])

  • 二分探索して、w = dp1[ h1 ][ h ][ w1 ][ k ] として dp1[ h1 ][ h ][ w ][ k ] >= w2 を満たす最大の h を求めて、chmax(dp2[ h1 ][ w1 ][ w2 ][ k + 1 ], h)

横に切れ目を入れる場合

  • 二分探索して、h = dp2[ h1 ][ w1 ][ w ][ k ] として dp2[ h ][ w1 ][ w ][ k ] >= h2 となる最大の w を求めて、chmax(dp1[ h1 ][ h2 ][ w1 ][ k + 1 ], w)

  • h = dp2[ h1 ][ w1 ][ w2 ][ k ] として、chmax(dp2[ h1 ][ w1 ][ w2 ][ k + 1 ], dp2[ h ][ w1 ][ w2 ][ k ])

さらに

二分探索で更新したところはしゃくとり法っぽくやれば  O(1) で更新できる。

#include <iostream>
#include <vector>
#include <string>
#include <cstring>
using namespace std;

template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

int H, W;
vector<string> fi;
const int MAX_V = 200;
const int MAX = 20;

long long sum[MAX_V][MAX_V];
int dp1[MAX_V][MAX_V][MAX_V][2];
int dp2[MAX_V][MAX_V][MAX_V][2];

int solve() {
    // 初期化
    sum[0][0] = 0;
    for (int h = 0; h < H; ++h) {
        for (int w = 0; w < W; ++w) {
            sum[h+1][w+1] = -sum[h][w] + sum[h+1][w] + sum[h][w+1] + (fi[h][w] == '#');
        }
    }
    for (int h1 = 0; h1 <= H; ++h1) {
        for (int h2 = h1; h2 <= H; ++h2) {
            int w2 = 0;
            for (int w1 = 0; w1 <= W; ++w1) {
                if (w2 < w1) w2 = w1;
                while (w2 < W &&
                       (sum[h2][w2+1]-sum[h1][w2+1]-sum[h2][w1]+sum[h1][w1] == 0 ||
                        sum[h2][w2+1]-sum[h1][w2+1]-sum[h2][w1]+sum[h1][w1] == (h2-h1)*(w2+1-w1)))
                    ++w2;
                dp1[h1][h2][w1][0] = w2;
            }
        }
    }
    for (int w1 = 0; w1 <= W; ++w1) {
        for (int w2 = w1; w2 <= W; ++w2) {
            int h2 = 0;
            for (int h1 = 0; h1 <= H; ++h1) {
                if (h2 < h1) h2 = h1;
                while (h2 < H &&
                       (sum[h2+1][w2]-sum[h1][w2]-sum[h2+1][w1]+sum[h1][w1] == 0 ||
                        sum[h2+1][w2]-sum[h1][w2]-sum[h2+1][w1]+sum[h1][w1] == (h2+1-h1)*(w2-w1)))
                    ++h2;
                dp2[h1][w1][w2][0] = h2;
            }
        }
    }

    // dp
    int res = MAX;
    for (int k = 0; k+1 < MAX; ++k) {
        // dp1
        for (int h1 = 0; h1 <= H; ++h1) {
            for (int h2 = h1; h2 <= H; ++h2) {
                for (int w1 = 0; w1 <= W; ++w1) {
                    int w = dp1[h1][h2][w1][k%2];
                    chmax(dp1[h1][h2][w1][(k+1)%2], dp1[h1][h2][w][k%2]);
                }
                int w2 = 0;
                for (int w1 = 0; w1 <= W; ++w1) {
                    if (w2 < w1) w2 = w1;
                    while (w2 < W && dp2[dp2[h1][w1][w2+1][k%2]][w1][w2+1][k%2] >= h2)
                        ++w2;
                    chmax(dp1[h1][h2][w1][(k+1)%2], w2);
                }
            }
        }
        // dp2
        for (int w1 = 0; w1 <= W; ++w1) {
            for (int w2 = w1; w2 <= W; ++w2) {
                for (int h1 = 0; h1 <= H; ++h1) {
                    int h = dp2[h1][w1][w2][k%2];
                    chmax(dp2[h1][w1][w2][(k+1)%2], dp2[h][w1][w2][k%2]);
                }
                int h2 = 0;
                for (int h1 = 0; h1 <= H; ++h1) {
                    if (h2 < h1) h2 = h1;
                    while (h2 < H && dp1[h1][h2+1][dp1[h1][h2+1][w1][k%2]][k%2] >= w2)
                        ++h2;
                    chmax(dp2[h1][w1][w2][(k+1)%2], h2);
                }
            }
        }

        if (dp1[0][H][0][k%2] == W) chmin(res, k);
        if (dp2[0][0][W][k%2] == H) chmin(res, k);

        for (int x = 0; x < MAX_V; ++x) {
            for (int y = 0; y < MAX_V; ++y) {
                for (int z = 0; z < MAX_V; ++z) {
                    dp1[x][y][z][k%2] = 0;
                    dp2[x][y][z][k%2] = 0;
                }
            }
        }       
    }
    return res;
}

int main() {
    while (cin >> H >> W) {
        fi.resize(H);
        for (int i = 0; i < H; ++i) cin >> fi[i];
        cout << solve() << endl;
    }
}