伝説の誤差問題!! 誤差について学べる、とても教育的な問題。
問題概要
人がコイントスをしました。各人には と番号がついています。 人目は、 回表が出て、 回裏が出ました。
人目のコイントスの成功率は と定義されます。
人 の番号を、成功率の高い順に並び替えてください。ただし、成功率が同じ人が複数いる場合、その中では人の番号が小さい順になるように並び替えてください。
制約
罠 1:double
型 (C++) では誤差 WA する
一見、次のようにすれば良いと思われるかもしれない。
{ // 添字 i, j を比較する関数 (大きい順に並べる) auto cmp = [&](int i, int j) -> bool { return (double)A[i] / (A[i] + B[i]) > (double)A[j] / (A[j] + B[j]); }; // 添字をソートする vector<int> ids(N); for (int i = 0; i < N; ++i) ids[i] = i; sort(ids.begin(), ids.end(), cmp); }
しかしこれだと実は誤差にやられてしまう。まず、double
型が表現できる相対誤差 (絶対値ではなく相対的な誤差) は 程度だと知られている。ここで、相対誤差という概念を簡単に整理しよう。
たとえば、 と を見分けるためには、 に対して、 の違いを見分ける必要がある。その見分けに要求される相対誤差は 分の 1......つまり、 よりも細かい相対誤差で表現できる必要があると言える。ここまでなら double
型でも見分け可能だ。
それでは、今回の問題で比較関数を作るときに、どの程度の相対精度 (相対誤差を見分けるための精度) が必要かを簡単に見積もってみよう。厳密な話はさておき、ざっくりとした理解でよければ
かけ算やわり算をすると、その値を見分けるための要求相対精度は乗算されていく
というように捉えて概ね問題ない (厳密な話は kyopro_friends さんのユーザ解説 より)。今回は、 という値を ごとに見分ける必要があるわけだが、
- という値を ごとに見分けるための要求相対精度は 程度
- という制約から、たとえば と という値を見分ける場面を考えるとよい
- という値を ごとに見分けるための要求相対精度も同様に 程度
となる。これらの割り算なので、結局、要求相対誤差は、これらを掛け合わせて 程度となる。一方、double
型で表現できる相対誤差は 程度なので足りないというわけだ。
よって、より精度の高い long double
型を使えば問題ない。
罠 2: の値が等しい部分
もう一つの罠は、 の値が等しい者同士は、番号が小さい順に並び替える必要がある。これはざっくり 2 つの方針がある
- 方針 1:安定ソートを用いる (C++ では
stable_sort()
) - 方針 2:タイの部分も考慮した比較関数を作る
なお、Python では普通の関数 sort()
がすでに安定ソートであることが保証されるので、それを用いて問題ない。下の実装例では、方針 1 を用いた。
コード
#include <bits/stdc++.h> using namespace std; int main() { int N; cin >> N; vector<long long> A(N), B(N); for (int i = 0; i < N; ++i) cin >> A[i] >> B[i]; // 添字 i, j を比較する関数 (大きい順に並べる) auto cmp = [&](int i, int j) -> bool { return (long double)A[i] / (A[i] + B[i]) > (long double)A[j] / (A[j] + B[j]); }; // 添字をソートする vector<int> ids(N); for (int i = 0; i < N; ++i) ids[i] = i; stable_sort(ids.begin(), ids.end(), cmp); // 出力 for (auto id : ids) cout << id+1 << " "; cout << endl; }
テクニック:整数型のみを用いる
更なるテクニックを。上では long double
型を用いて通したものの、それでもやっぱり誤差の不安はつきまとう。基本的に誤差のことを考えると、浮動小数点型を用いるよりは、整数型のみで完結できるなら整数型を用いたい。
今回も、実は比較関数をいじることで、整数型のみで完結できる。比較関数の中身である
を式変形しよう。分母を払うと、次のようになる。
この式は整数型のみで判定可能である!!
コード
#include <bits/stdc++.h> using namespace std; int main() { int N; cin >> N; vector<long long> A(N), B(N); for (int i = 0; i < N; ++i) cin >> A[i] >> B[i]; // 添字 i, j を比較する関数 (大きい順に並べる) auto cmp = [&](int i, int j) -> bool { return A[i] * (A[j] + B[j]) > A[j] * (A[i] + B[i]); }; // 添字をソートする vector<int> ids(N); for (int i = 0; i < N; ++i) ids[i] = i; stable_sort(ids.begin(), ids.end(), cmp); // 出力 for (auto id : ids) cout << id+1 << " "; cout << endl; }