けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 213 H - Stroll (赤色, 600 点)

オンライン分割統治 FFT 面白いね!!
公式解説がとても丁寧なので、備忘録程度に

問題概要

頂点数  N のが与えられます。 M 組の頂点対に対して、次のように無向辺を張っていきます。 i 組めの頂点対に対しては

  • 長さ 1 の辺が  p_{i, 1}
  • 長さ 2 の辺が  p_{i, 2}
  • ...
  • 長さ  T の辺が  p_{i, T}

というように辺を張っていきます。このように作られたグラフにおいて、頂点 0 から出発して頂点 0 へと戻ってくる長さ  T のウォークの本数を 998244353 で割ったあまりを求めてください。

制約

  •  2 \le N \le 10
  •  1 \le T \le 4 \times 10^{4}

まずは DP

いかにも計算量的に間に合わないことはわかりつつも、まずは愚直な DP を立てることが大事な気がする。

dp[v][t] ← 距離  t だけ進んで、頂点  v に到達するような場合の数

このとき、 v を終点にもつような各辺  e = (u, v) に対して、

dp[v][t] +=  \displaystyle \sum_{i = 0}^{t-1} dp[u][t - i] × p[e][i]

という遷移が立てられる。この時点で  O(MT^{2}) にはなっているけど、とても間に合わないということで悩んでいた。そのジレンマは公式解説にまさに書かれていた。

僕もこの式が畳み込みっぽいとは思っていて、でも dp[t] の計算に dp[0], dp[1], ..., dp[t-1] も必要で難しいな...となった。それと類似の状況は下に貼る問題でも生じていた。畳み込みの添字に i < j という順序関係があって難しいという問題だった。そしてそのときに FFT が役に立ったことを思い出してはいた。なので今回も FFT 的にできないかな...という気持ちはあった。

drken1215.hatenablog.com

ただそのあとを詰められなかった。コンテスト後に解説を見て、分割統治 FFT (オンライン FFT) という考え方を知った。上に貼った問題もオンライン FFT とは違うけど、似た感じなのかなと思った。分割統治 FFT の解説については公式解説がとても丁寧なのでそちらに。

備忘録として、統治部分の式を。

t = mid, mid + 1, ..., right - 1 に対して

dp[v][t] +=  \displaystyle \sum_{i = {\rm left}}^{{\rm mid}} dp[u][i] × p[e][t - i]

計算量は  O(M T (\log T)^{2}) となる。なんか順序っぽい構造が入った FFT は分割統治法を使うとうまくいくことがあるよ、ということで頭に留めたいと思う。

コード

#include <bits/stdc++.h>
using namespace std;

#include "atcoder/convolution.hpp"
#include "atcoder/modint.hpp"
using namespace atcoder;
using mint = modint998244353;

int main() {
    // 入力
    int N, M, T;
    cin >> N >> M >> T;
    vector<int> A(M * 2), B(M * 2);
    vector<vector<mint>> P(M * 2, vector<mint>(T + 1));
    for (int i = 0; i < M; ++i) {
        cin >> A[i] >> B[i];
        --A[i], --B[i];
        B[i + M] = A[i], A[i + M] = B[i];
        for (int t = 1; t <= T; ++t) {
            long long p;
            cin >> p;
            P[i + M][t] = P[i][t] = p;
        }
    }

    // 分割統治 FFT
    vector<vector<mint>> dp(N, vector<mint>(T + 1, 0));
    dp[0][0] = 1;
    auto rec = [&](auto self, int left, int right) -> void {
        if (right - left <= 1) return;

        int mid = (left + right) / 2;

        // まず左半分を更新
        self(self, left, mid);

        // 左半分から右半分への遷移を更新
        for (int e = 0; e < M * 2; ++e) {
            int u = A[e], v = B[e];

            vector<mint> L(mid - left, 0), R(right - left, 0);
            for (int t = left; t < mid; ++t) L[t - left] = dp[u][t];
            for (int t = 0; t < right - left; ++t) R[t] = P[e][t]; 
            
            auto seki = convolution(L, R);
            for (int t = mid; t < right; ++t) {
                dp[v][t] += seki[t - left];
            }
        }
        
        // 最後に右半分を更新
        self(self, mid, right);
    };
    rec(rec, 0, T + 1);
    cout << dp[0][T].val() << endl;
}