けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 311 G - One More Grid Task (3D, 黄色, 575 点)

黒マスを避けながら、長方形領域の値の総和を最大化する問題として解いた!

問題概要

 N \times M のグリッドがあって、各マス  (i, j) には正の整数  A_{i, j} が書かれている。

グリッドに含まれる長方形領域のうち、「長方形領域に含まれる値の総和」と「長方形領域に含まれる値の最小値」の積の最大値を求めよ。

制約

  •  1 \le N, M \le 300
  •  1 \le A_{i, j} \le 300

解法 (1) (僕自身の考えたこと)

コンテスト中、色々考えた。Cartesian 木が使えるんじゃないかとかも (使える解法も学んだので、後述する)。ただ、やっぱり 4 乗からの改善が難しく感じた。

やっぱり、 A_{i, j} \le 300 という制約が大事なんじゃないかと考えて、次の方針を考えた。


 v = 1, 2, \dots, 300 に対して、

  •  A_{i, j} \lt v なるマス  (i, j) は含まないように
  • 長方形領域をとって、その総和が最大となるようにする問題を考える
  • その最大値を  S(v) として、 v \times S(v) の最大値を答えればよい

 v に対して、この問題を  O(NM) で解ければよい。なんだかとても典型っぽい問題になった!

グリッド内の長方形領域の総和の最大値

問題をより簡単にして、「白黒グリッドにおいて、黒色マスを含まないように白色マスのみで長方形領域をとり、その面積の最大値を求めよ」という問題なら有名問題だ!!

https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DPL_3_B&lang=ja

今回の問題は、各マスに重みがあるバージョンだと言える。その場合も同様の解法で解ける。

各行  i を固定して、そこから下の領域のみを考える。そうすると有名な「ヒストグラム内部の長方形」を考える問題になる。

次に、ヒストグラム内部の極大長方形を列挙する。具体的には、各列  j に対して、その列のヒストグラムの高さを左右に広げられるところまで広げる。どこまで広げられるのかは stack を使っ線形時間で求める有名な方法がある (蟻本などにある)。

 j に対して、上図のようにして求められる極大長方形について、それに含まれるマスの値の総和を求めていき、それらの最大値を求めればよい。

以上より、全体の計算量は  O(NM \max{A}) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// min value of H within [left[i], right[i]) <= H[i]
pair<vector<int>, vector<int>> solve_left_right(const vector<long long> &H) {
    int N = (int)H.size();
    vector<int> left(N, 0), right(N, N);
    
    // left
    stack<pair<long long, int>> stack_left;
    for(int i = 0; i < N; ++i) {
        while (!stack_left.empty() && H[i] <= stack_left.top().first)
            stack_left.pop();
        if (!stack_left.empty()) left[i] = stack_left.top().second + 1;
        stack_left.push({H[i], i});
    }
    
    // right
    stack<pair<long long, int>> stack_right;
    for(int i = N-1; i >= 0; --i) {
        while (!stack_right.empty() && H[i] <= stack_right.top().first)
            stack_right.pop();
        if (!stack_right.empty()) right[i] = stack_right.top().second;
        stack_right.push({H[i], i});
    }
    return {left, right};
}

int main() {
    // 入力
    int H, W;
    cin >> H >> W;
    vector<vector<long long>> A(H, vector<long long>(W));
    for (int i = 0; i < H; ++i) for (int j = 0; j < W; ++j) cin >> A[i][j];
    
    // 2 次元累積和
    vector<vector<long long>> S(H+1, vector<long long>(W+1, 0));
    for (int i = 0; i < H; ++i) {
        for (int j = 0; j < W; ++j) {
            S[i+1][j+1] = S[i+1][j] + S[i][j+1] - S[i][j] + A[i][j];
        }
    }
    
    // 長方形領域 [lx, rx) x [ly, ry) の総和
    auto calc = [&](int lx, int rx, int ly, int ry) -> long long {
        return S[min(rx, H)][min(ry, W)]
        - S[min(rx, H)][ly] - S[lx][min(ry, W)] + S[lx][ly];
    };
    
    // 最小値が v 以上の場合を求める (v 未満のマスを禁止する)
    long long res = 0;
    for (int v = 0; v <= 300; ++v) {
        // 各マスを起点として下側に可能マスが何個連続するかを求める
        vector<vector<long long>> len(H+1, vector<long long>(W, 0));
        for (int j = 0; j < W; ++j) {
            for (int i = H-1; i >= 0; --i) {
                if (A[i][j] < v) len[i][j] = 0;
                else len[i][j] = len[i+1][j] + 1;
            }
        }
        
        // 行 i を起点として、その下側のヒストグラムを考える
        for (int i = 0; i < H; ++i) {
            auto [left, right] = solve_left_right(len[i]);
            
            // マス (i, j) を起点として、下方向、左方向、右方向に限界まで伸ばした長方形領域の総和
            for (int j = 0; j < W; ++j) {
                long long sum = calc(i, i+len[i][j], left[j], right[j]);
                res = max(res, sum * v);
            }
        }
    }
    cout << res << endl;
}

 

解法 (2)

 A_{i, j} の値によらない解法がある模様。横方向の区間  \lbrack R, L) を固定して考えることにする。

このとき、各行  i について、区間  \lbrack R, L) 内部に含まれる値の総和を  B_{i}、最小値を  C_{i} とすると、次の問題を解けばよいことになる。


サイズ  N の数列  B_{i} C_{i} が与えられる。適切な区間   \lbrack l, r) を定めることで、

 \displaystyle (\sum_{i \in \lbrack l, r)}B_{i}) \times (\min_{i \in \lbrack l, r)}C_{i})

の値の最大化せよ。


これは、Cartesian 木上の DP で  O(N) の計算量で解決できる。よって、全体として  O(NM^{2}) の計算量で解ける。

コード

#include <bits/stdc++.h>
using namespace std;

// Cartesian Tree
template<class T> struct CartesianTree {
    int root;  // root
    vector<int> par, left, right;

    CartesianTree() {}
    CartesianTree(const vector<T>& v) : root(0)
    , par(v.size(), -1), left(v.size(), -1), right(v.size(), -1) {
        vector<int> st(v.size(), 0);
        int top = 0;
        for (int i = 1; i < v.size(); ++i) {
            if (v[st[top]] > v[i]) {
                while (top >= 1 && v[st[top - 1]] > v[i]) --top;
                par[left[i] = st[top]] = i;
                if (top == 0) root = i;
                else right[par[i] = st[top - 1]] = i;
                st[top] = i;
            } else {
                right[par[i] = st[top]] = i;
                st[++top] = i;
            }
        }
    }
};

int main() {
    // 入力
    const int INF = 1<<29;
    int N, M;
    cin >> N >> M;
    vector<vector<int>> A(N, vector<int>(M));
    for (int i = 0; i < N; ++i) for (int j = 0; j < M; ++j) cin >> A[i][j];
    
    // 横方向の区間 [L, R) ごとに解く
    long long res = 0;
    for (int L = 0; L < M; ++L) {
        vector<int> B(N, 0), C(N, INF);
        for (int R = L; R < M; ++R) {
            // solve problem [L, R)
            for (int i = 0; i < N; ++i) {
                B[i] += A[i][R];
                C[i] = min(C[i], A[i][R]);
            }
             
            // build cartesian tree
            CartesianTree<int> ct(C);
            
            // build ruisekiwa
            vector<long long> S(N+1, 0);
            for (int i = 0; i < N; ++i) S[i+1] = S[i] + B[i];
            
            // solve [l, r)
            auto dfs = [&](auto self, int l, int r, int cur) -> long long {
                long long res = (S[r] - S[l]) * C[cur];
                if (ct.left[cur] != -1) {
                    res = max(res, self(self, l, cur, ct.left[cur]));
                }
                if (ct.right[cur] != -1) {
                    res = max(res, self(self, cur+1, r, ct.right[cur]));
                }
                return res;
            };
            res = max(res, dfs(dfs, 0, N, ct.root));
        }
    }
    cout << res << endl;
}