けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

フォルシアゆるふわ競プロオンサイト #3 E - Sweets Distribution(Hard)

面白かった。セグ木にこういうの乗っけるの楽しい!

問題へのリンク

問題概要

 4 \times N の盤面の各マスに整数値が書かれている。このマスに対して、適切に  0 \lt i \lt j \lt k \le N を決めて、

  • 盤面の 0 行目の区間  \lbrack 0, i) の総和
  • 盤面の 1 行目の区間  \lbrack i, j) の総和
  • 盤面の 2 行目の区間  \lbrack j, k) の総和
  • 盤面の 3 行目の区間  \lbrack k, N) の総和

の総和の最大値を求めるタスクを考える。今、 Q 個のクエリが与えられ、 i 番目のクエリは盤面の  l_{i} 列目と  r_{i} 列目を swap する。各クエリ終了後の盤面に対して、上記の最大値を求めよ。

制約

  •  4 \le N \le 2 \times 10^{5}
  •  1 \le Q \le 2 \times 10^{4}

考えたこと

まず  Q = 1 の場合は単純な DP で解くことができる。さて、クエリ処理前後で盤面や、DP の様相がどのように変化するのかをちょっと考えてみよう。

a 列目と b 列目とを入れ替えるとき、a 列目や b 列目を含まない区間については、i 行目から j 行目までいたる最適経路は特に変化しないことがわかる。a 行目と b 行目を含むような区間については変化する。その辺りのことを考えると、なんとなくセグメント木に乗るのでは...という気がしてくる。

実際乗るのだ。セグメント木の各区間 [left, right) には以下の 4 × 4 の情報を持たせるのだ。

  • left 列目から right 列目まで (区間 [left, right)) に対して
  • v[ i ][ j ] := (i, left) から出発して (j-1, right-1) に到着するまでの最適経路

とする。こうしておくと、区間の併合をこんな風に記述することができる。左側の情報を a、右側の情報を b としたとき、0 <= i < j <= 4 に対して、

  • chmax(res[ i ][ j ], a[ i ][ k ] + b[ k ][ j ]);
  • chmax(res[ i ][ j ], a[ i ][ k ] + b[ k - 1 ][ j ]);

という感じに併合できる。以上をもとにセグメント木で解く。計算量は  O(Q \log{N}) となる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

template<class Monoid> struct SegTree {
    using Func = function<Monoid(Monoid, Monoid)>;
    const Func F;
    const Monoid UNITY;
    int SIZE_R;
    vector<Monoid> dat;
    
    SegTree(int n, const Func f, const Monoid &unity): F(f), UNITY(unity) { init(n); }
    void init(int n) {
        SIZE_R = 1;
        while (SIZE_R < n) SIZE_R *= 2;
        dat.assign(SIZE_R * 2, UNITY);
    }
    
    /* set, a is 0-indexed */
    void set(int a, const Monoid &v) { dat[a + SIZE_R] = v; }
    void build() {
        for (int k = SIZE_R - 1; k > 0; --k)
            dat[k] = F(dat[k*2], dat[k*2+1]);
    }
    
    /* update a, a is 0-indexed */
    void update(int a, const Monoid &v) {
        int k = a + SIZE_R;
        dat[k] = v;
        while (k >>= 1) dat[k] = F(dat[k*2], dat[k*2+1]);
    }
    
    /* get [a, b), a and b are 0-indexed */
    Monoid get(int a, int b) {
        Monoid vleft = UNITY, vright = UNITY;
        for (int left = a + SIZE_R, right = b + SIZE_R; left < right; left >>= 1, right >>= 1) {
            if (left & 1) vleft = F(vleft, dat[left++]);
            if (right & 1) vright = F(dat[--right], vright);
        }                                                                                                              
        return F(vleft, vright);
    }
    inline Monoid operator [] (int a) { return dat[a + SIZE_R]; }
    
    /* debug */
    void print() {
        for (int i = 0; i < SIZE_R; ++i) {
            cout << (*this)[i];
            if (i != SIZE_R-1) cout << ",";
        }
        cout << endl;
    }
};

const long long INF = 1LL<<40;
int N, Q;
vector<vector<long long>> v;

using Node = vector<vector<long long>>;
Node make(vector<long long> unit) {
    Node res(5, vector<long long>(5, -INF));
    for (int i = 0; i < 4; ++i) res[i][i+1] = unit[i];
    return res;
}

void solve() {
    Node unity(5, vector<long long>(5, -1)); // 単位元
    auto func = [&](const Node &a, const Node &b) {
        if (a == unity) return b;
        else if (b == unity) return a;
        Node res(5, vector<long long>(5, -INF));
        for (int i = 0; i < 4; ++i) {
            for (int j = i+1; j <= 4; ++j) {
                for (int k = i+1; k <= j; ++k) {
                    chmax(res[i][j], a[i][k] + b[k][j]);
                    chmax(res[i][j], a[i][k] + b[k-1][j]);
                }
            }
        }
        return res;
    };
    SegTree<Node> seg(N+1, func, unity);
    for (int i = 0; i < N; ++i) seg.set(i, make(v[i]));
    seg.build();

    while (Q--) {
        int l, r;
        cin >> l >> r;
        --l, --r;
        swap(v[l], v[r]);
        seg.update(l, make(v[l]));
        seg.update(r, make(v[r]));
        cout << seg.get(0, N)[0][4] << endl;
    }
}

int main() {
    cin >> N >> Q;
    v.assign(N, vector<long long>(4));
    for (int i = 0; i < 4; ++i) for (int j = 0; j < N; ++j) cin >> v[j][i];
    solve();
}