けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces Round #584 - Dasha Code Championship E. Rotate Columns (R2400)

すんごく横長な行列に関する問題だけど、実は横の長さを縦の長さ以下にできるというタイプ。

そうすれば  n \times n の問題に帰着できて、 n \le 12。あとは  O(3^{n}) の bit DP...なのだが、TLE がとれなかった。微妙に詰めが甘かった。。。

問題へのリンク

問題概要

 n \times m の行列があたえられる。各列に対して "巡回操作" を好きな回数だけ施すことができる。

それを行ったのちの、「各行の値の最大値の総和」の最大値を求めよ。

制約

  • (テストケース数)  \le 40
  •  1 \le n \le 12
  •  1 \le m \le 2000

考えたこと

 n に関する制約がいかにも bitDP だといっている...愚直にやると  O(m3^{n}) な bitDP になる。このままだと間に合わない。

しかし、 nm 個の整数の中から選ばれる数は  n 個しかないので、横長行列の  m 列のうち、最終的に関わってくるのは高々  n 列しかない。これをなんとか生かしたい。

少し不安ながらも、

  • 各列の最大値を  b_1, b_2, \dots, b_m として
  • これが大きい順に  n 個を選ぶ

という解は少なくとも実行可能であって、最適解はこれより小さくなることはないことに着目する。そして少し考えると、この選んだ  n 個の値が属する列ベクトル (それらの列の index を  j_1, \dots, j_n とする) 以外の列ベクトルを選ぶような解に対しては、かならず解を悪化させることなく、 j_1, \dots, j_n の範囲から構成した解へと変形できることがわかる。

したがって、横長行列を  j_1, \dots, j_n 番目の列ベクトルのみを切り出したものに探索範囲を絞ってもかまわないことが示された。

bitDP

あとは  O(n3^{n}) な bitDP で解くことができる。しかし、巡回操作に関する部分がめんどい。そしてここを雑にやると TLE する。

bitDP の遷移を考える前に、あらかじめ各列の各 bit について、巡回操作によって得られる利得の最大値を前処理して求めておくとよい。

#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
using namespace std;

template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

int T;
int n, m;
vector<vector<long long> > a;
long long dp[5000];

long long solve() {
    if (n <= m) {
        vector<pair<long long, int> > v;
        for (int j = 0; j < m; ++j) {
            long long val = 0;
            for (int i = 0; i < n; ++i) chmax(val, a[i][j]);
            v.push_back({val, j});
        }
        sort(v.begin(), v.end(), greater<pair<long long,int> >());
        vector<vector<long long> > a2(n, vector<long long>(n, 0));
        for (int k = 0; k < n; ++k) {
            int j = v[k].second;
            for (int i = 0; i < n; ++i) {
                a2[i][k] = a[i][j];
            }
        }
        a = a2;
        m = n;
    }
    
    memset(dp, 0, sizeof(dp));
    for (int j = 0; j < m; ++j) {
        // 前処理
        vector<long long> score(1<<n, 0);
        for (int bit = 0; bit < (1<<n); ++bit) {
            for (int shift = 0; shift < n; ++shift) {
                long long tmp = 0;
                for (int i = 0; i < n; ++i)
                    if (bit & (1<<i))
                        tmp += a[(shift+i)%n][j];
                chmax(score[bit], tmp);
            }
        }
        
        // O(3^n) DP
        for (int bit = (1<<n)-1; bit >= 0; --bit) {
            int cbit = ((1<<n)-1) - bit;
            for (int bit2 = cbit; ; bit2 = (bit2 - 1) & cbit) {
                int nbit = bit | bit2;
                chmax(dp[nbit], dp[bit] + score[bit2]);
                if (!bit2) break;
            }
        }
    }
    return dp[(1<<n)-1];
}

int main() {
    cin >> T;
    for (int _ = 0; _ < T; ++_) {
        cin >> n >> m;
        a.assign(n, vector<long long>(m, 0));
        for (int i = 0; i < n; ++i) for (int j = 0; j < m; ++j) cin >> a[i][j];
        cout << solve() << endl;
    }
}