けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 150 E - Change a Little Bit (500 点)

面白かった

問題へのリンク

問題概要

長さ  N の整数列  c_{0}, \dots, c_{N-1} が与えられる。

長さ  N の 0 と 1 からなる文字列  S, T に対して定まる関数  f(S, T) は次のようになっている。 f(S, T) は、次のようにして文字列  S を文字列  T に一致させるのに必要な最小コストとする。

  •  k (1, 2, \dots) 回目の操作で、 S の文字  S_{i} を選んで 0 は 1 に、1 は 0 にする
  • このときにかかるコストは  k \times c_{i} であって、これが加算される

あらゆる  S, T の組合せを考えたときの、 f(S, T) の総和を 1000000007 で割ったあまりを求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

まず、 S は 0000...0 としてかまわない。それ以外の場合であっても、0000...0 の場合に答えは一致する。よって、 S を 0000...0 に固定した場合の、 2^{N} 通りの  T に対して  f(S, T) の総和を求めて、最後にそれに  2^{N} をかければ OK。

さて、 T に 1 が  K 個あるとき、 K 回の操作を行うことになるが、それらの index を  i_{0}, \dots, i_{K-1} とすると

  •  1 \times c_{i_{0}}
  •  2 \times c_{i_{1}}
  • ...
  •  (K+1) \times c_{i_{K}}

を総和していくことになる。 c は大きい順にやった方がよいことがわかる。考察を簡単にするために、入力  c もあらかじめ大きい順にソートすることにする。

個別の要素に

 n (= 0, 1, \dots, N-1) 番目の要素について、それらが  k 回加算されるような  T が何個あるのかを考えることにする。

  •  1 番目に採用されるような  T C(n, 0)
  •  2 番目に採用されるような  T C(n, 1)
  • ...
  •  n+1 番目に採用されるような  T C(n, n)

ということになっている。よって、

  •  C_{n} \times (1 \times C(n, 0) + 2 \times C(n, 1) + \dots, (n+1) \times C(n, n)

を計算して、これを  n について合計すれば OK。二項係数について

  •  C(n, 0) + C(n, 1) + \dots + C(n, n) = 2^{n}
  •  0 \times C(n, 0) + 1 \times C(n, 1) + \dots + n \times C(n, n) = n 2^{n-1}

が成立する。前者は有名だが、後者はあまり有名でないかもしれない。後者は次のようにしてわかる。一般に

  •  C(n, k) = \frac{n}{k} C(n-1, k-1)

が成立することを利用すると、

 0 \times C(n, 0) + 1 \times C(n, 1) + \dots + n \times C(n, n)
 = n(C(n-1, 0) + C(n-1, 1) + \dots + C(n-1, n-1) = n 2^{n-1})

となる。以上から、

 1 \times C(n, 0) + 2 \times C(n, 1) + \dots, (n+1) \times C(n, n)
=  2^{n} + n2^{n-1}

となることがわかった。

式変形しなくても

二項係数計算は、式変形頑張らなくても、意味を考えればできるっぽい。解説放送より。

https://www.youtube.com/watch?v=9MphwmIsO7Q

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};

const int MOD = 1000000007;
using mint = Fp<MOD>;

int N;
vector<long long> c;

mint solve() {
    if (N == 1) return mint(c[0]) * 2;
    sort(c.begin(), c.end(), greater<long long>()); 
    mint res = 0;
    mint n2 = modpow(mint(2), N-2);
    mint n1 = modpow(mint(2), N-1);
    for (int n = 0; n < N; ++n) {
        mint fac = n2 * n + n1;
        res += fac * c[n];
    }
    return res * modpow(mint(2), N);
}

int main() {
    while (cin >> N) {
        c.resize(N);
        for (int i = 0; i < N; ++i) cin >> c[i];
        cout << solve() << endl;
    }
}