けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 058 E - 和風いろはちゃん (3D, 橙色, 700 点)

5 + 7 + 5 = 17 なの、よくできてる!

問題概要

正の整数  X, Y, Z が与えられる。

 1 以上  10 以下の数値からなる長さ  N の数列  a_{1}, \dots, a_{N} であって、以下の条件を満たすものの個数を 1000000007 で割ったあまりを求めよ。

  • 整数  0 \le x \lt y \lt z \lt w \le N が存在して、
  •  a の区間  \lbrack x, y) の総和が  X
  •  a の区間  \lbrack y, z) の総和が  Y
  •  a の区間  \lbrack z, w) の総和が  Z

制約

  •  3 \le N \le 40
  •  1 \le X \le 5
  •  1 \le Y \le 7
  •  1 \le Z \le 5

考えたこと

第一感だと、

  • 数列のうち、総和が  X, Y, Z となる区間の位置を決め打ちして
  • 区間内の総和が  X, Y, Z となるような個数を数える

という風にしたくなる。しかしそのような方針では、ダブルカウントの脅威を取り除くことがとても厳しそう。同じ数列でも、総和が  X, Y, Z となるような区間の取り方が幾通りもあるケースがあるのだ。そのようなものをすべて考慮するとなると気が遠くなる。

DP で状態を全部もつ

代わりに、DP しよう。そしてざっくりとした方針としては、「数列の今の状態において最後尾付近は結局どのような区間を形成しうるか」という情報をすべて押し込めてしまうのが良さそう。そこで、 S 0 以上  2^{X + Y + Z} 未満の整数として、

  • dp[ i ][ S ][ 0 or 1 ] := 数列のうち最初の i 項までを決めたとき、最後尾から連続する数値の総和として  1, 2, \dots, X+Y+Z のうちのどの値をとりうるかの情報を表すビット状態が  S であるようなもののうち、(0: まだ XYZ が形成されていない、1: すでに XYZ が形成済み) となるようなものの個数

という風にする。こうすると、ダブルカウントしそうな部位も全部まるごと DP 添字として管理できる。計算量は  O(AN2^{X+Y+Z}) ( A は数列の各項の値の種類数) となる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() const { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept {
        if (n == 0) return 1;
        if (n < 0) return modpow(modinv(r), -n);
        auto t = modpow(r, n / 2);
        t = t * t;
        if (n & 1) t = t * r;
        return t;
    }
    friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        return Fp<MOD>(u);
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

int main() {
    int N, X, Y, Z;
    cin >> N >> X >> Y >> Z;
    int S = X+Y+Z;
    
    auto nex = [&](long long bit, int v) {
        long long nbit = (bit<<v) % (1<<S);
        if (v-1 < S) nbit |= 1<<(v-1);
        return nbit;
    };
    auto clear = [&](int bit) {
        if (!(bit & (1<<(Z-1)))) return false;
        if (!(bit & (1<<(Y+Z-1)))) return false;
        if (!(bit & (1<<(X+Y+Z-1)))) return false;
        return true;
    };

    vector<vector<mint>> dp(1<<S, vector<mint>(2, 0)), ndp = dp;
    dp[0][0] = 1;
    for (int i = 0; i < N; ++i) {
        ndp.assign(1<<S, vector<mint>(2, 0));
        for (int bit = 0; bit < (1<<S); ++bit) {
            for (int v = 1; v <= 10; ++v) {
                int nbit = nex(bit, v);
                if (clear(nbit)) ndp[nbit][1] += dp[bit][0];
                else ndp[nbit][0] += dp[bit][0];
                ndp[nbit][1] += dp[bit][1];
            }
        }
        swap(dp, ndp);
    }
    mint res = 0;
    for (int bit = 0; bit < (1<<S); ++bit) res += dp[bit][1];
    cout << res << endl;
}