けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

よくやる二項係数 (nCk mod. p)、逆元 (a^-1 mod. p) の求め方

1. 典型的な二項係数の求め方 (1 ≦ k ≦ n ≦ 107 程度)

競プロをしていると、nCk mod. p を計算する場面にしばしば出くわします。時と場合によって色んな方法が考えられますが、以下のものを頻繁に使用するイメージです。多くの AtCoder のトッププレイヤーたちも使用している形式でそれなりに高速です。5 年前のりんごさんのツイートが 1 つのきっかけとなって広まった印象があります:

使い方としては、最初に一度前処理として COMinit() をしておけば、あとは COM(n, k) 関数を呼べばよいです。

  • 前処理 COMinit():  O(n)
  • クエリ処理 COM(n, k):  O(1)
#include <iostream>
using namespace std;

const int MAX = 510000;
const int MOD = 1000000007;

long long fac[MAX], finv[MAX], inv[MAX];

// テーブルを作る前処理
void COMinit() {
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for (int i = 2; i < MAX; i++){
        fac[i] = fac[i - 1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD / i) % MOD;
        finv[i] = finv[i - 1] * inv[i] % MOD;
    }
}

// 二項係数計算
long long COM(int n, int k){
    if (n < k) return 0;
    if (n < 0 || k < 0) return 0;
    return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
}

int main() {
    // 前処理
    COMinit();

    // 計算例
    cout << COM(100000, 50000) << endl;
}

注意

上記の実装は十分高速ではありますが、inv 配列が確実に不要な場面では、

  • fac 配列を計算する
  • finv[n] を逆元計算によって計算しておく
  • finv[i-1] = finv[i] * i % MOD によって finv 配列を後ろから計算していく

とした方が少し速いようです (@CuriousFairy315 さんより)

1-1. 使用可能場面

  •  1 \le k \le n \le 10^{7}
  •  p素数 (上の実装ではさらに  p > n を仮定している)

1-2. 使用原理

nCk = \frac{n!}{k!(n-k)!} = (n!) × (k!)^{-1} × ((n-k)!)^{-1}

であることを利用しています。COMinit() で、a! (fac[a]) と (a!)^{-1} (finv[a]) のテーブルを予め作っています。これを作っておくことで、クエリ計算が掛け算のみになって高速になります。

fac[0], fac[1], ..., fac[n-1] の計算が O(n) でできることは難しくない感じです。いわゆる累積和ならぬ累積積をやっている感じです。一方、finv[0], finv[1], ..., finv[n-1] の計算も実は O(n) でできることは驚きです。finv を計算するために、mod. p における 1, 2, ..., n の逆元 inv[1], inv[2], ..., inv[n] を  O(n) で求めます。そうすれば、inv の累積積をとることで finv も  O(n) で計算できます。

p素数としたとき、mod. p での逆元計算方法には大きく分けて

  • 拡張 Euclid の互除法
  • Fermat の小定理

とがあります。ともに O(\log{p}) かかりますが、多くの場合、拡張 Euclid の互除法の方が高速に動作します。さて、愚直に 1, 2, ..., n の逆元を計算していては O(n\log{p}) かかってしまいます。ところが

  • i の逆元を、p % i (これは i より小さいことに注意) の逆元を利用して O(1) で求める

という魔法のようなテクニックがあります。そのテクニックを用いることで、mod. p における 1, 2, ..., n の逆元を  O(n) で計算できます。そのことを理解するために、そもそも拡張 Euclid の互除法が何をしていたかを考えます。

1-3. 拡張 Euclid の互除法による逆元計算

 a^{-1} mod. p を計算するとはすなわち

 ax + py = 1

を満たす  x を求めたいということになります。Euclid の互除法を適用します。すなわち、 p a で割ってみます:

 p= qa + r

これを代入すると

 ax + (qa + r)y = 1 ⇔ ry + a(x + qy) = 1

になります。これによって  (a, p) に関する問題が、それより数値の小さな  (r, a) に関する問題に帰着できました。これを再帰的に解くのが拡張 Euclid の互除法です。具体的には  (r, a) に関する小問題を解いて

 rs + at = 1

と解  (s, t)再帰的に得られたとすると、

 y = s, x + qy = t ⇔ x = t - qs, y = s

という風に、元の問題の解を構成できます。下に  a^{-1} (mod. m) を求める実装を示します。なお注意点として、

  • 逆元を求める mod. m m素数でなくても、 a m が互いに素であればよい
  • 下の拡張 Euclid の互除法自体は  a b が互いに素でなくても適用できるが、逆元計算するときの  a m とは互いに素である必要がある

となっています。

#include <iostream>
using namespace std;

// ax + by = gcd(a, b) となるような (x, y) を求める
// 多くの場合 a と b は互いに素として ax + by = 1 となる (x, y) を求める
long long extGCD(long long a, long long b, long long &x, long long &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    long long d = extGCD(b, a%b, y, x); // 再帰的に解く
    y -= a / b * x;
    return d;
}

// 負の数にも対応した mod (a = -11 とかでも OK) 
inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}

// 逆元計算 (ここでは a と m が互いに素であることが必要)
long long modinv(long long a, long long m) {
    long long x, y;
    extGCD(a, m, x, y);
    return mod(x, m); // 気持ち的には x % m だが、x が負かもしれないので
}

int main() {
    // 計算例
    cout << modinv(3, 7) << endl;
}

1-4. 改めて、inv[1], inv[2], ..., inv[n] を高速に計算する方法

拡張 Euclid の互除法におけるアイディアを少し変形して実現します。 a^{-1} を求めるために同じように

 ax + py = 1

を満たす  x を求めて行きます。 p a で割って

 p = qa + r

とするのですが、ここから上手いことやります。まず、 ax + py = 1 の両辺を  q 倍して変形していくと、

 qax + qpy = q ⇔ (p - r)x + qpy = q ⇔ rx + p(-x - qy) = -q

となります。拡張 Euclid の互除法では  (a, p) に関する問題を  (r, a) に関する問題に帰着していたのに対し、今回は  (r, p) に関する問題に帰着しています。 p を残していることがミソですね。さて、 z = -x - qy とおいて、 rx + pz = -q を満たす  x (とz) を求めることができれば万事解決ということになります。

 rx + pz = 1

を満たす  (x, z) (s, t) とすると、 rs + pt = 1 であり、これを両辺  -q 倍することで

 r(-sq) + p(-tq) = -q

となるので、 x = -sq, z = -tq rx + pz = -q を満たします。

  •  s = r^{-1} = (p %  a)^{-1}
  •  q = (p / a)

であることに注意すると、

 a^{-1} ≡ - (p %  a)^{-1} × (p / a) (mod.  p)

であることが導かれました。実装上は

inv[a] = MOD - inv[MOD % a] * (MOD / i) % MOD;

という風にします。

1-5. 逆元漸化式のもう 1 つの導出方法

上では拡張 Euclid の互除法を意識した導出を示しましたが、もっと直接的に導くこともできます。takapt さんの記事にもある導出方法です。 p a で割ると

 p = (p/a) × a + (p%a)

で、これを変形することによって直接導出することができます。この両辺の mod. p をとると、

 (p/a) × a + (p%a) ≡ 0

 ⇔ (p/a) + (p%a) × a^{-1} ≡ 0 (両辺に a^{-1} をかける)

 ⇔ (p%a) × a^{-1} ≡ -(p/a)

 ⇔ a^{-1} ≡ - (p% a)^{-1} × (p/a)

という風に簡潔に導かれます。

2. n がさらに巨大で固定値なとき (1 ≦ n ≦ 109, 1 ≦ k ≦ 107 程度)

 n が巨大なときは先程の手が使えません。しかし  k が小さければ望みがあります。

 nCk = \frac{n}{1} × \frac{n-1}{2} × ... × \frac{n-k+1}{k}

を利用して  O(k) で計算することができます。さらに  n が固定値の場合も多く、そんなときは配列テーブル

com[ k ] = nCk

 O(k) で前計算しておくことが有効です。

3. n も k も小さいとき ( 1≦ k ≦ n ≦ 2000 程度、mod. p が素数でなくてもよい)

動的計画法によって nCk のテーブルを生成する方法が有力で、mod の p が素数でなくてもよいのが魅力的です。

const long long MOD = 1000000007;
const int MAX_C = 1000;
long long Com[MAX_C][MAX_C];

void calc_com() {
    memset(Com, 0, sizeof(Com));
    Com[0][0] = 1;
    for (int i = 1; i < MAX_C; ++i) {
        Com[i][0] = 1;
        for (int j = 1; j < MAX_C; ++j) {
            Com[i][j] = (Com[i-1][j-1] + Com[i-1][j]) % MOD;
        }
    }
}

4. さらに

n と k がそれほど小さくなく mod が素数でない場合など、いくらでもイヤな場合は考えられますが、そこまで来たら uwi さんの記事を読めば大抵のことは解決します:

5. 二項係数を用いる問題たち

二項係数を用いる問題たちです。