けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

【ライブラリ】mod の値が大きいときの mod 演算

昨日の CSA 068 DIV2 E Sliding Product Sum で、まさかのときのために作っていたライブラリがドンピシャで役に立ちました!!!

「1000000007 で割った余りを答えよ」という問題は多いですが、「m が与えられて m で割った余りを答えよ」となると往々にして辛いですね。ましてや、m が int 型に収まらないとなると...


long long a = ?, b = ?
(a * b) % m


というなんでもない計算が long long 型オーバーフローしてしまいます。対策として、

a を b 回足しあげる、ただし繰り返し二乗法を流用する

というのがあります。a + a が long long 型オーバーフローしない限りは大丈夫です。

inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}
inline long long mul(long long a, long long b, long long m) {
    a = mod(a, m); b = mod(b, m);
    if (b == 0) return 0;
    long long res = mul(mod(a + a, m), b>>1, m);
    if (b & 1) res = mod(res + a, m);
    return res;
}
inline long long inv(long long a, long long m) {
    long long b = m, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a = mod(a - mul(t, b, m), m); swap(a, b);
        u = mod(u - mul(t, v, m), m); swap(u, v);
    }
    return mod(u, m);
}

CSA 068 DIV2 E Sliding Product Sum


[1, 2, 3, …, N] の連続する k 要素の積としてありうる和をとる操作を 1 <= k <= K について行って、すべての総和をとって M で割ったあまりを求めよ。

・1 <= N <= 1018
・1 <= K <= min(600, N)
・1 <= M <= 1018


高校数学で習った和の中抜けを使います。そうすると

 (N + 1) N (N - 1) … (N - k + 1) / (k + 1)

を 1 <= k <= K について足しあげる問題になります。愚直に個別に足しあげても O(K2) なので大丈夫です。

ただし、M が大きい上に素数とは限らないので注意が必要です。基本的には掛け算も割り算も上のライブラリを使えばいいのですが、k + 1 と M が互いに素でないケースがいやです。しかし少し考えてみると、

N+1, N, N-1, ..., N-k+1

は連続する k + 1 個の整数なのでそのうちの 1 個が k + 1 で割り切れます。そこだけ k + 1 で割ってあげれば掛け算の問題になります。

#include <iostream>
using namespace std;

long long N, K, M;

inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}

inline long long mul(long long a, long long b, long long m) {
    a = mod(a, m); b = mod(b, m);
    if (b == 0) return 0;
    long long res = mul(mod(a + a, m), b>>1, m);
    if (b & 1) res = mod(res + a, m);
    return res;
}

long long solve() {
    long long res = 0;
    for (long long k = 1; k <= K; ++k) {
        long long tmp = 1;
        bool ok = false;
        for (long long num = N-k+1; num <= N+1; ++num) {
            long long tnum = num;
            if (!ok && num % (k+1) == 0) {
                ok = true;
                tnum = num / (k+1);
            }
            tmp = mul(tmp, tnum, M);
        }
        res += tmp;
        res %= M;
    }
    return res;
}


int main() {
    while (cin >> N >> K >> M) {
        cout << solve() << endl;
    }
}