けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ARC 074 E - RGB Sequence (800 点)

わかってる...!わかってるんだ!!!この問題が超ド典型だってことくらい!!!!!!!!

でも典型だからって、そんなパッと解けるわけじゃない。すごく苦手なんだこういうの。。。

問題へのリンク

問題概要

 N マスを RGB の三色で塗り分ける  3^{N} 通りの方法のうち、以下の  M 個の条件をすべて満たすものの個数を 1000000007 で割ったあまりを求めよ。

  •  l_i, r_i が与えられて、区間  \lbrack l_i, r_i \rbrack 内に登場する色の種類数はちょうど  x_i である ( x_i = 1, 2, 3)

制約

  •  1 \le N, M \le 300

考えたこと

区間について条件を満たしながらどうのこうのしていく問題。こういうのは

  • 区間の終端でソートしておく
  • dp[ i ] := 最初の i マス分についての何か、みたいな DP を立てておいて、i を終端とする区間を順に処理していく

みたいな DP をすればよいと相場は決まってる。

DP の詰め

  • 最後に色を塗ったところからどこまで一色しかないか
  • 最後に色を塗ったところからどこまで二色しかないか

といった情報があれば DP の更新ができそう。なので

  • dp[ i ][ j ][ k ] := 最初の k マスを色ぬりしていて、マス k-1 と異なる色が最初に登場するのがマス j-1 で、マス k-1 とも j-1 とも異なる色が最初に登場するのがマス i-1 であるようなもののうち、右端が k 以下であるような区間に関する条件は満たすようなものの個数

とするとよさそう。そして k マス目に新たに色を塗るときに、

  • k-1 マス目に塗ったやつを塗って大丈夫ならそれを塗る
  • j-1 マス目に塗ったやつを塗って大丈夫ならそれを塗る
  • i-1 マス目に塗ったやつを塗って大丈夫ならそれを塗る

という三種類の分岐をしていく。計算量は

  • DP テーブルのノード数  O(N^{3})
  • DP 遷移数  O(N^{2}M)
#include <iostream>
#include <vector>
#include <map>
#include <cstring>
#include <numeric>
#include <algorithm>
using namespace std;

// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MAX = 310;
const int MOD = 1000000007;
using mint = Fp<MOD>;

int N, M;
using pint = pair<int,int>;
vector<pint> inters[MAX];

bool check(int i, int j, int l, int x) {
    int num = 1;
    if (l < i) ++num;
    if (l < j) ++num;
    return num == x;
}

mint dp[MAX][MAX][MAX];
mint solve() {
    memset(dp, 0, sizeof(dp));
    dp[0][0][0] = 1;
    for (int i = 0; i <= N; ++i) {
        for (int j = i; j <= N; ++j) {
            for (int k = j; k <= N; ++k) {
                int r = k+1;
                // i -> r+1
                {
                    bool ok = true;
                    for (auto p : inters[r]) if (!check(j, k, p.first, p.second)) ok = false;
                    if (ok) dp[j][k][r] += dp[i][j][k];
                }
                // j -> r+1
                {
                    bool ok = true;
                    for (auto p : inters[r]) if (!check(i, k, p.first, p.second)) ok = false;
                    if (ok) dp[i][k][r] += dp[i][j][k];
                }
                // k -> r+1
                {
                    bool ok = true;
                    for (auto p : inters[r]) if (!check(i, j, p.first, p.second)) ok = false;
                    if (ok) dp[i][j][r] += dp[i][j][k];
                }
            }
        }
    }
    mint res = 0;
    for (int i = 0; i <= N; ++i) for (int j = 0; j <= N; ++j) res += dp[i][j][N];
    return res;
}

int main() {
    cin >> N >> M;
    for (int i = 0; i < M; ++i) {
        int l, r, x; cin >> l >> r >> x; --l;
        inters[r].push_back({l, x});
    }
    for (int r = 0; r < MAX; ++r) sort(inters[r].begin(), inters[r].end());
    cout << solve() << endl;
}