floor sum!!! コンテスト中に思いつけてよかった!
問題概要
皿 があって、皿 には 個の石が乗っている。また、空の袋がある。
あなたは以下の 2 種類の操作を好きな順番で 0 回以上何度でも行うことができる。
- 石が 1 個以上載っている皿全てから石を 1 個ずつ取り、取った石は袋に移動する
- 袋から石を 個取り出し、全ての皿に 1 個ずつ石を載せる。
操作後の各皿の石の個数のなす数列 としてありうる個数を、998244353 で割った余りを求めよ。
制約
考えたこと
たとえば、 を考える。作れる数は次のようにあんる。
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
- で袋 個 → これを起点とした加算で 個作れる
これを見ると、「等差数列の各項を で割った商の和」になっている。よって、floor sum が刺さる!
コード
#include <bits/stdc++.h> using namespace std; // sum_{i=0}^{n-1} floor((a * i + b) / m) // O(log(n + m + a + b)) // __int128 can be used for T template<class T> T floor_sum(T n, T a, T b, T m) { if (n == 0) return 0; T res = 0; if (a >= m) { res += n * (n - 1) * (a / m) / 2; a %= m; } if (b >= m) { res += n * (b / m); b %= m; } if (a == 0) return res; T ymax = (a * n + b) / m, xmax = ymax * m - b; if (ymax == 0) return res; res += (n - (xmax + a - 1) / a) * ymax; res += floor_sum(ymax, m, (a - xmax % a) % a, a); return res; } // #lp under (and on) the segment (x1, y1)-(x2, y2) // not including y = 0, x = x2 template<class T> T num_lattice_points(T x1, T y1, T x2, T y2) { T dx = x2 - x1; return floor_sum(dx, y2 - y1, dx * y1, dx); } // modint template<int MOD> struct Fp { // inner value long long val; // constructor constexpr Fp() noexcept : val(0) { } constexpr Fp(long long v) noexcept : val(v % MOD) { if (val < 0) val += MOD; } constexpr long long get() const noexcept { return val; } constexpr int get_mod() const noexcept { return MOD; } // arithmetic operators constexpr Fp operator - () const noexcept { return val ? MOD - val : 0; } constexpr Fp operator + (const Fp &r) const noexcept { return Fp(*this) += r; } constexpr Fp operator - (const Fp &r) const noexcept { return Fp(*this) -= r; } constexpr Fp operator * (const Fp &r) const noexcept { return Fp(*this) *= r; } constexpr Fp operator / (const Fp &r) const noexcept { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp &r) noexcept { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp &r) noexcept { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp &r) noexcept { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp &r) noexcept { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr Fp pow(long long n) const noexcept { Fp res(1), mul(*this); while (n > 0) { if (n & 1) res *= mul; mul *= mul; n >>= 1; } return res; } constexpr Fp inv() const noexcept { Fp res(1), div(*this); return res / div; } // other operators constexpr bool operator == (const Fp &r) const noexcept { return this->val == r.val; } constexpr bool operator != (const Fp &r) const noexcept { return this->val != r.val; } friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) noexcept { is >> x.val; x.val %= MOD; if (x.val < 0) x.val += MOD; return is; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) noexcept { return os << x.val; } friend constexpr Fp<MOD> modpow(const Fp<MOD> &r, long long n) noexcept { return r.pow(n); } friend constexpr Fp<MOD> modinv(const Fp<MOD> &r) noexcept { return r.inv(); } }; using i128 = __int128_t; const int MOD = 998244353; using mint = Fp<MOD>; int main() { long long N; cin >> N; vector<long long> A(N); for (int i = 0; i < N; ++i) cin >> A[i]; sort(A.begin(), A.end()); long long sum = 0; mint res = 0; for (int i = 0; i < N; ++i) { long long diff = A[i] - (i > 0 ? A[i-1] : 0); sum += diff * (N - i); long long num = (i+1 < N ? A[i+1] - A[i] : 1); i128 tmp = 0; if (i+1 < N) { tmp = floor_sum(i128(num), i128(N-i-1), i128(sum), i128(N)); tmp += num; } else { tmp = sum / N + 1; } long long add = tmp % MOD; res += add; } cout << res << endl; }