けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 154 D - Dice in Line (茶色, 400 点)

期待値の線形性!!!

問題へのリンク

問題概要

 N 個のサイコロが左から右に一列に並べてある。 i 番目のサイコロは目が  1, 2, \dots, p_{i} となっていて、これらが当確率に出る。

隣接する  K 個のサイコロを選んでそれぞれ独立に振ったとき、出る目の合計の期待値の最大値を求めよ。

制約

  •  1 \le K \le N \le 2 \times 10^{5}

考えたこと

出た!!!!!!!!期待値!!!!!!!!!!!

まず隣接する  K 個のサイコロを選ぶ方法を最適化せよと言われているけど、とりあえずまずは以下の問題を解いてみよう。元の問題を解くためには、まずはその部分的な問題を解いていくのがいいと思う。


 K 個のサイコロがあって、それぞれ目の種類は  p_{1}, p_{2}, \dots, p_{K} だけある。

 K 個のサイコロを振って出た目の和の期待値を求めよ。


まずサイコロ 1 個 (目の種類が  p) だった場合は、

  •  1 の目が出る確率は  \frac{1}{p}
  •  2 の目が出る確率は  \frac{1}{p}
  • ...
  •  p の目が出る確率は  \frac{1}{p}

となるので、期待値は

 1 \times \frac{1}{p} + 2 \times \frac{1}{p} + \dots + p \times \frac{1}{p}
=  \frac{p+1}{2}

となる。なおもっと簡単に、 1, 2, ..., p の平均は  \frac{p+1}{2} であることは直感的にも納得できる。つまり最小が 1 で、最大が p で、その中点である  \frac{p+1}{2} が平均というわけだ。

期待値の線形性

それでは  K 個のサイコロの目の合計値の期待値はどうだろうか。実はこれはこんな風に簡単に考えることができる。

  • 1 個目のサイコロの出す目の平均値は  \frac{p_{1} + 1}{2}
  • 2 個目のサイコロの出す目の平均値は  \frac{p_{2} + 1}{2}
  • ...
  •  K 個目のサイコロの出す目の平均値は  \frac{p_{K} + 1}{2}

よって、これらを合計して、 K 個のサイコロの目の合計値の期待値は

  •  \frac{p_{1} + 1}{2} + \frac{p_{2} + 1}{2} + \dots + \frac{p_{K} + 1}{2}

となる。

元の問題へ

こうして  N 個のサイコロから  K 個を選んだときの期待値を求めることには成功した。最後の問題は、この値が最大となるように  K 個を選びとることだ。

これは  N 個の値

 \frac{p_{1} + 1}{2}, \frac{p_{2} + 1}{2}, \dots, \frac{p_{N} + 1}{2}

の中から連続する  K 個の総和の最大値を求めればよい。ここまできたら累積和で求めることができる。累積和についてはここに。

qiita.com

なお、実装上の工夫として、 N 個の値

 \frac{p_{1} + 1}{2}, \frac{p_{2} + 1}{2}, \dots, \frac{p_{N} + 1}{2}

を一様に 2 倍して

 p_{1} + 1, p_{2} + 1, \dots, p_{N} + 1

という風にしても、最大となる  K 個の値は変わらない。最後に /2 をすればよい。こうすれば、最後の最後に /2 するまでは整数型のみで計算することができる。

#include <bits/stdc++.h>
using namespace std;

int main() {
    int N, K; cin >> N >> K;
    vector<long long> p(N);
    for (int i = 0; i < N; ++i) cin >> p[i], ++p[i]; // 1 足しておく

    // 累積和
    vector<long long> s(N+1, 0);
    for (int i = 0; i < N; ++i) s[i+1] = s[i] + p[i];

    // K 差を見ていく
    long long res = 0;
    for (int i = 0; i+K <= N; ++i) res = max(res, s[i+K] - s[i]);
    cout << fixed << setprecision(10) << (double)(res)/2 << endl;
}