けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 130 E - Common Subsequence (青色, 500 点)

共通部分列に関する問題!!!!!
最長共通部分列問題は有名だけど、今回は共通部分列を数え上げる問題。

問題へのリンク

問題概要

2 つの数列  S, T が与えられる。 S T の共通部分列が何通りあるかを求めよ。

ただし、 S T から抜き取ってできる文字列が同じものであったとしても、抜き取る添字が異なるものは異なるものとしてみなすこととする。

制約

  •  1 \le N, M \le 2 \times 10^{3}

考えたこと

一目みて LCS (Longest Common Subsequence) を連想すると思う。この記事の問題 8 でもある。

典型的な DP (動的計画法) のパターンを整理 Part 1 ~ ナップサック DP 編 ~ - Qiita

これの DP 遷移を真似しようとすると一瞬詰まることになる。すなわち

  • dp[ i ][ j ] := S の最初 i 文字と、T の最初 j 文字とからできる共通部分列の個数

として、

  • dp[ i ][ j ] += dp[ i - 1 ][ j ]
  • dp[ i ][ j ] += dp[ i ][ j - 1 ]
  • S[ i-1 ] == T[ j-1] のとき、dp[ i ][ j ] += dp[ i - 1 ][ j - 1 ]

という風にしたくなるのだ。しかしこれだと答えが大きくなってしまうことにとまどった人も多そうである。。。

からくりは


S の i-1 文字目までと T の j-1 文字目までについての共通部分列 P について、この P は、dp[ i ][ j - 1 ] でも dp[ i - 1 ][ j ] でもカウントされている


ということだ!!!!!!
これを回避する道は 2 つある!!!

  1. ダブルカウントを引く
  2. dp の定義を工夫する

多くの場合の王道は 2 番目の方法だと思う (editorial でも 2 番目の方法)。ここでは両方とも検討してみる。

解法 1:ダブルカウントを引く

上記のことさえわかっていれば、実はほんの少し LCS の真似をした DP を変更すればよいのだ。

  • dp[ i ][ j ] += dp[ i - 1 ][ j ]
  • dp[ i ][ j ] += dp[ i ][ j - 1 ]

との間で、dp[ i - 1 ][ j - 1 ] から来ている分、つまり「 S の最初 i - 1 文字と T の最初 j - 1 文字の共通部分列であるやつらが、ちょうど dp[ i - 1 ][ j ] にも dp[ i ][ j - 1 ] にも取り込まれていて、これらが重複して数えられている」ということが問題なのだ。

よって、上記の処理に

  • dp[ i ][ j ] -= dp[ i - 1 ][ j - 1 ]

を追加するだけでよい!!!!!!!
なお下の実装は長いが、modint を使っているだけである。

#include <iostream>
#include <cstring>
#include <vector>
using namespace std;


// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;


int N, M;
vector<int> S, T;
mint dp[2100][2100];

int main() {
    cin >> N >> M;
    S.resize(N), T.resize(M);
    for (int i = 0; i < N; ++i) cin >> S[i];
    for (int i = 0; i < M; ++i) cin >> T[i];

    memset(dp, 0, sizeof(dp));
    dp[0][0] = 1;
    for (int i = 0; i <= N; ++i) {
        for (int j = 0; j <= M; ++j) {
            if (i-1 >= 0 && j-1 >= 0 && S[i-1] == T[j-1]) dp[i][j] += dp[i-1][j-1];
            if (i-1 >= 0) dp[i][j] += dp[i-1][j];
            if (j-1 >= 0) dp[i][j] += dp[i][j-1];
            if (i-1 >= 0 && j-1 >= 0) dp[i][j] -= dp[i-1][j-1]; // ダブルカウントを除く
        }
    }
    cout << dp[N][M] << endl;
}

解法 2:DP の定義を工夫する

こういうダブルカウントが発生しそうな場面で使える常套テクニックがある。dp テーブルの定義をちょっと変える。

  • dp[ i ][ j ] := S の最初 i 文字と、T の最初 j 文字とからできる共通部分列のうち、S の最初の i 文字目のラスト (S[i-1]) と T の最初の j 文字目のラスト (T[j-1]) とが等しくて、かつこれを採用しているようなものの個数

とする。こうすると、dp 遷移は、S[ i - 1] == T[ j - 1 ] の部分について

  • dp[ i ][ j ] = sum_{0 <= x < i, 0 <= y < j} dp[ x ][ y ]

という風になる。つまり dp[ i 未満][ j 未満] の総和ということだ。とてもシンプルだ。このままだと  O(N^{2}M^{2}) の計算量がかかってしまうように思える。。。

しかし、これは二次元累積和を用いてあげることできれいに更新できるようになる。全体として  O(1) で解くことができる。

  • sdp[ i + 1 ][ j + 1 ] = sum_{0 <= x <= i, 0 <= y <= j} dp[ x ][ y ]

とする。こうすると先ほどの DP の式は

  • dp[ i ][ j ] = sdp[ i ][ j ]

と簡潔にかける。累積和の更新は

  • sdp[ i + 1 ][ j + 1 ] = sdp[ i + 1 ][ j ] + sdp[ i ][ j + 1 ] - sdp[ i ][ j ] + dp[ i ][ j ];

で OK。最後の答えは sdp[ N + 1 ][ M + 1 ] になる。

#include <iostream>
#include <cstring>
#include <vector>
using namespace std;


// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;


int N, M;
vector<int> S, T;
mint dp[2100][2100];
mint sdp[2100][2100];

int main() {
    cin >> N >> M;
    S.resize(N), T.resize(M);
    for (int i = 0; i < N; ++i) cin >> S[i];
    for (int i = 0; i < M; ++i) cin >> T[i];

    memset(dp, 0, sizeof(dp));
    memset(sdp, 0, sizeof(dp));
    dp[0][0] = 1;
    sdp[1][1] = 1;
    for (int i = 0; i <= N; ++i) {
        for (int j = 0; j <= M; ++j) {
            if (i == 0 && j == 0) continue;
            if (i-1 >= 0 && j-1 >= 0 && S[i-1] == T[j-1]) {
                dp[i][j] = sdp[i][j];
            }
            sdp[i+1][j+1] = sdp[i+1][j] + sdp[i][j+1] - sdp[i][j] + dp[i][j];
        }
    }
    cout << sdp[N+1][M+1] << endl;
}