お 見た目とても面白そう
問題概要
要素からなる数列 があたえられる。各要素を「赤」「緑」「青」の三色のいずれかに塗る方法のうち、各色の合計値を として三辺の長さが となるような三角形が存在するようなものを数え上げよ。998244353 で割ったあまりで答えよ。
制約
考えたこと
いかにも部分和問題っぽい DP はする雰囲気の問題ではある。ちゃんと詰めるのは大変そうな雰囲気を感じる。
さて、三角形の成立条件は、数列 の合計値を とするとき
- 3 数のうちの最大値が より小さい
と言い表すことができる。一般性を失わずに が最大として、この最大値を決め打って、 は好きなようにすればよい。。。と一瞬したくなるのだが、このとき や が を超えてしまう可能性があってダメ。このままでは数え辛い。
ということで補集合をとることにする。すなわち、三角形が成立しないような場合を数える。そうすると
- 3 数のうちの最大値が 以上である、または
- or or
というようなものを数えることになる。そうすると今度こそ、 を最大としたときに、 は適当に決めたとしても、概ね や が 以上になることはない。ただ や といった例外がある。われわれが数えたい補集合を整理すると
- 「 の場合」 × 3
- 「 で > の場合」 × 3 ( が偶数の場合のみ)
- 二辺が で一辺が の場合 ( が偶数の場合のみ)
ということになる。これらを合計すればよい。
DP
- dp[ i ][ v ] := 最初の i 個の塗り方のうち、赤の和が v になる場合の数 (緑 / 青の塗り方も考慮)
- dp2[ i ][ v ] := 最初の i 個から何個か選んで合計を v にする場合の数
とすると、数え上げられる。dp2 はおなじみなのでいいとして、dp は
- dp[ i + 1 ][ v + a[ i ] ] += dp[ i ][ v ] (赤)
- dp[ i + 1 ][ v ] += dp[ i ][ v ] * 2 (青か緑)
という感じ。そうすると、求めたい補集合は
- dp[ N ][ (S/2 + 1 以上) ] × 3
- (dp[ N ][ S/2 ] - dp2[ N ][ S/2 ] × 2) × 3 ( が偶数の場合のみ)
- dp2[ N ][ S/2 ] × 3 ( が偶数の場合のみ)
を合計したものになる。
#include <iostream> #include <vector> #include <set> #include <algorithm> using namespace std; template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) v += MOD; } constexpr int getmod() { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b; swap(a, b); u -= t * v; swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr istream& operator >> (istream &is, Fp<MOD>& x) noexcept { return is >> x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept { if (n == 0) return 1; auto t = modpow(a, n / 2); t = t * t; if (n & 1) t = t * a; return t; } }; const int MOD = 998244353; using mint = Fp<MOD>; int main() { int N; cin >> N; vector<long long> a(N); long long S = 0; for (int i = 0; i < N; ++i) cin >> a[i], S += a[i]; vector<vector<mint> > dp(N+1, vector<mint>(S+1, 0)); dp[0][0] = 1; auto dp2 = dp; // 単純な部分和 for (int i = 0; i < N; ++i) { for (int v = 0; v + a[i] <= S; ++v) { dp[i+1][v + a[i]] += dp[i][v]; dp[i+1][v] += dp[i][v] * 2; dp2[i+1][v + a[i]] += dp2[i][v]; dp2[i+1][v] += dp2[i][v]; } } mint all = modpow(mint(3), N); mint hosyugo = 0; for (int i = S/2+1; i <= S; ++i) hosyugo += dp[N][i] * 3; if (S % 2 == 0) { hosyugo += (dp[N][S/2] - dp2[N][S/2] * 2) * 3; hosyugo += dp2[N][S/2] * 3; } cout << all - hosyugo << endl; }