これが生まれて初めての作問だった!!!
AOJ-ICPC で ☆4 個ついてて嬉しい!
問題概要
虫食算が与えられる。解の個数を 1000000007 で割ったあまりを求めたい。
より正確には長さ の 3 個の文字列 が与えられる。これらは '?' か '0'〜'9' の文字で構成されている。 の '?' に '0'〜'9' の数値を当てはめる方法のうち
- 完成された を整数値としてみたときに、 が成立
- の先頭は '0' ではない
という条件を満たすものの個数を求めよ。
制約
- の先頭の文字は '0' ではない
サンプル
3?4 12? 5?6
これの解は以下の 2 通りある。
- 384 + 122 = 506
- 394 + 122 = 516
解法
1 の位から順番に数値を当てはめていく (文字列としては末尾から見ていくことになるので、あらかじめ reverse しておく)。この際に「繰り上がり」の発生についても考慮していく。
さて、d 桁目まで数値を当てはめた状態では、以下のいずれかの状態がありうる。
- 繰り上がりが 0 (繰り上がりなし)
- 繰り上がりが 1
これらの情報を持って DP しよう。
- dp[ d ][ 0 or 1 ] := d 桁目まで数値を当てはめたときに、繰り上がりが (0 or 1) であるようにするときの場合の数
最終的な答えは dp[ N ][ 0 ] となる (最後は繰り上がってはいけない)。
DP
DP の遷移は、一見実装が大変だけど、次のように for 文三重ループを回してしまうと楽できる。下の実装では
- num0:d 桁目から d+1 桁目にかけて繰り上がりが発生しないように、d 桁目に数値を入れる方法の数
- num1:d 桁目から d+1 桁目にかけて繰り上がりが発生するように、d 桁目に数値を入れる方法の数
としている。これらの値を用いて、DP 遷移を行う。計算量は定数倍が重めの となる。
int num0 = 0, num1 = 0; for (int a = 0; a <= 9; ++a) { for (int b = 0; b <= 9; ++b) { for (int c = 0; c <= 9; ++c) { if (d 桁目で、A に a, B に b, C に c を当てはめるのが valid) { if (繰り上がりなし) ++num0; else ++num1; } } } }
コード
#include <bits/stdc++.h> using namespace std; // modint template<int MOD> struct Fp { long long val; constexpr Fp(long long v = 0) noexcept : val(v % MOD) { if (val < 0) val += MOD; } constexpr int getmod() const { return MOD; } constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp& r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp& r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp& r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr bool operator == (const Fp& r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp& r) const noexcept { return this->val != r.val; } friend constexpr istream& operator >> (istream& is, Fp<MOD>& x) noexcept { is >> x.val; x.val %= MOD; if (x.val < 0) x.val += MOD; return is; } friend constexpr ostream& operator << (ostream& os, const Fp<MOD>& x) noexcept { return os << x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD>& r, long long n) noexcept { if (n == 0) return 1; if (n < 0) return modpow(modinv(r), -n); auto t = modpow(r, n / 2); t = t * t; if (n & 1) t = t * r; return t; } friend constexpr Fp<MOD> modinv(const Fp<MOD>& r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } return Fp<MOD>(u); } }; const int MOD = 1000000007; using mint = Fp<MOD>; mint solve(string A, string B, string C) { int N = (int)A.size(); reverse(A.begin(), A.end()); reverse(B.begin(), B.end()); reverse(C.begin(), C.end()); auto isvalid = [&](int a, int b, int c, int d, int kuriagari) { // 先頭に 0 はダメ if (d == N-1) if (a == 0 || b == 0 || c == 0) return false; // すでに入っている数値と矛盾がないか if (A[d] != '?' && A[d] != (char)('0'+a)) return false; if (B[d] != '?' && B[d] != (char)('0'+b)) return false; if (C[d] != '?' && C[d] != (char)('0'+c)) return false; // ok return true; }; vector<vector<mint>> dp(N+1, vector<mint>(2, 0)); dp[0][0] = 1; for (int d = 0; d < N; ++d) { for (int kuriagari = 0; kuriagari <= 1; ++kuriagari) { int num0 = 0, num1 = 0; for (int a = 0; a <= 9; ++a) { for (int b = 0; b <= 9; ++b) { for (int c = 0; c <= 9; ++c) { if (isvalid(a, b, c, d, kuriagari)) { if (a + b + kuriagari == c) ++num0; else if (a + b + kuriagari == c + 10) ++num1; } } } } dp[d+1][0] += dp[d][kuriagari] * num0; dp[d+1][1] += dp[d][kuriagari] * num1; } } return dp[N][0]; } int main() { string A, B, C; while (cin >> A) { if (A == "0") break; cin >> B >> C; cout << solve(A, B, C) << endl; } }