けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Tenka1 2019 D - Three Colors (600 点)

見た目とても面白そう

問題へのリンク

問題概要

 N 要素からなる数列  a があたえられる。各要素を「赤」「緑」「青」の三色のいずれかに塗る方法のうち、各色の合計値を  R, G, B として三辺の長さが  R, G, B となるような三角形が存在するようなものを数え上げよ。998244353 で割ったあまりで答えよ。

制約

  •  3 \le N \le 300
  •  1 \le a_i \le 300

考えたこと

いかにも部分和問題っぽい DP はする雰囲気の問題ではある。ちゃんと詰めるのは大変そうな雰囲気を感じる。

さて、三角形の成立条件は、数列  a の合計値を  S とするとき

  • 3 数のうちの最大値が  \frac{S}{2} より小さい
  •  R, G, B \ge 0

と言い表すことができる。一般性を失わずに  R が最大として、この最大値を決め打って、 G, B は好きなようにすればよい。。。と一瞬したくなるのだが、このとき  G B R を超えてしまう可能性があってダメ。このままでは数え辛い。

ということで補集合をとることにする。すなわち、三角形が成立しないような場合を数える。そうすると

  • 3 数のうちの最大値が  \frac{S}{2} 以上である、または
  •  R = 0 or  G = 0 or  B = 0

というようなものを数えることになる。そうすると今度こそ、 R を最大としたときに、 G, B は適当に決めたとしても、概ね  G B R 以上になることはない。ただ  R = G = \frac{S}{2} R = B = \frac{S}{2} といった例外がある。われわれが数えたい補集合を整理すると

  •  R \ge S/2 + 1 の場合」 × 3
  •  R = S/2 G, B >  0 の場合」 × 3 ( S が偶数の場合のみ)
  • 二辺が  S/2 で一辺が  0 の場合 ( S が偶数の場合のみ)

ということになる。これらを合計すればよい。

DP

  • dp[ i ][ v ] := 最初の i 個の塗り方のうち、赤の和が v になる場合の数 (緑 / 青の塗り方も考慮)
  • dp2[ i ][ v ] := 最初の i 個から何個か選んで合計を v にする場合の数

とすると、数え上げられる。dp2 はおなじみなのでいいとして、dp は

  • dp[ i + 1 ][ v + a[ i ] ] += dp[ i ][ v ] (赤)
  • dp[ i + 1 ][ v ] += dp[ i ][ v ] * 2 (青か緑)

という感じ。そうすると、求めたい補集合は

  • dp[ N ][ (S/2 + 1 以上) ] × 3
  • (dp[ N ][ S/2 ] - dp2[ N ][ S/2 ] × 2) × 3 ( S が偶数の場合のみ)
  • dp2[ N ][ S/2 ] × 3 ( S が偶数の場合のみ)

を合計したものになる。

#include <iostream>
#include <vector>
#include <set>
#include <algorithm>
using namespace std;


template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) v += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept {
        return is >> x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};


const int MOD = 998244353;
using mint = Fp<MOD>;


int main() {
    int N; cin >> N;
    vector<long long> a(N);
    long long S = 0;
    for (int i = 0; i < N; ++i) cin >> a[i], S += a[i];

    vector<vector<mint> > dp(N+1, vector<mint>(S+1, 0));
    dp[0][0] = 1;
    auto dp2 = dp; // 単純な部分和

    for (int i = 0; i < N; ++i) {
        for (int v = 0; v + a[i] <= S; ++v) {
            dp[i+1][v + a[i]] += dp[i][v];
            dp[i+1][v] += dp[i][v] * 2;

            dp2[i+1][v + a[i]] += dp2[i][v];
            dp2[i+1][v] += dp2[i][v];
        }
    }
    mint all = modpow(mint(3), N);
    mint hosyugo = 0;
    for (int i = S/2+1; i <= S; ++i) hosyugo += dp[N][i] * 3;
    if (S % 2 == 0) {
        hosyugo += (dp[N][S/2] - dp2[N][S/2] * 2) * 3;
        hosyugo += dp2[N][S/2] * 3;
    }
    cout << all - hosyugo << endl;
}