けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 152 E - Flatten (500 点)

「〜の最小値を求めてください」「ただし 1000000007 で割ったあまりで」
...この設定の歪さを見ると、不思議な気持ちになる

問題へのリンク

問題概要

 N 個の正の整数  A_{1}, A_{2}, \dots, A_{N} が与えられる。以下の条件を満たす正の整数列  B_{1}, B_{2}, \dots, B_{N} をすべて考えたときの、 B_{1} + B_{2} + \dots + B_{N} の最小値を求めよ (1000000007 で割ったあまりで)

  •  A_{1} \times B_{1} = A_{2} \times B_{2} = \dots = A_{N} \times B_{N} が成立

制約

  •  1 \le N \le 10^{4}
  •  1 \le A_{i} \le 10^{6}

考えたこと

1000000007 で割ったあまりを求めよ、というのは通常は「数え上げ問題」で問われる。「最小値を求めよ」という問題でそれが要求されるのは珍しい。

というのも、たとえば 6 と 1000000008 が答えの候補として絞れたとして、実際の答えは 6 であるば、mod をとってしまうとそれぞれ 6 と 1 になってしまって大小関係がよくわからなくなってしまう。

なので、「最小値を求めよ」という問題で 1000000007 で割ったあまりを要求されたら、かなりの確率で「答えを導くなんらかの計算式が導出できそう」という予感がする。その式が導けたならば、いつも通り、途中過程を 1000000007 で割りながら計算していけばよい。

さて、少し考えると、

  • 等しい数を  L = A_{1} \times B_{1} = A_{2} \times B_{2} = \dots = A_{N} \times B_{N} とおいてみる
  • このとき、 L A_{1}, A_{2}, \dots, A_{N} を割り切る、つまり  L A_{1}, A_{2}, \dots, A_{N} の公倍数である
  •  L は小さければ小さいほどよいので、 A_{1}, A_{2}, \dots, A_{N}最小公倍数とすればよい

ということがわかる。以上から、問題は以下のことに帰着された。


  •  A_{1}, A_{2}, \dots, A_{N} の最小公倍数を求めて  L として

  •  \frac{L}{B_{1}} + \frac{L}{B_{2}} + \dots + \frac{L}{B_{N}} が答え


しかし  L は 104 個もの値の最小公倍数ということで、ものすごく大きな整数になりうる。具体的には、仮に  A がすべて互いに素であるとすると、 6 \times 10^{4} 桁くらいになりうるわけだ ( A_{i} \le 10^{6} なので)。こんな  L を真っ向から扱うのは大変だ。

1. L を 1000000007 で割ったあまりを求める

まず大前提として、 \frac{L}{B_{1}} + \frac{L}{B_{2}} + \dots + \frac{L}{B_{N}} を計算したいと思ったとき、 L を 1000000007 で割ったあまりとしてしまうと、 B_{1} とかで割れなくなってしまうのでは...という不安を抱いた方は多いかもしれない。

しかし!!!その心配はない!!!!!mod. 1000000007 の世界で  L ÷ B_{i} を計算すればよいのだ!!!その考え方はここにも書いた

qiita.com

さて、3 つ以上の整数の最小公倍数の求め方を考えよう。たとえば

  •  120 = 2^{3} \times 3^{1} \times 5^{1}
  •  63 = 3^{2} \times 7^{1}
  •  250 = 2^{1} \times 5^{3}

の最小公倍数は、それぞれ素因数分解したときの各素数の「指数」の最大をとっていくことで求められ、

  •  L = 2^{3} \times 3^{2} \times 5^{3} \times 7^{1} = 63000

となる。まとめると


  •  A_{1}, A_{2}, \dots, A_{N}素因数分解をする
  • 各素因数ごとに指数の最大値を拾って掛け算していくことで、最小公倍数  L を求める (この過程で 1000000007 で割っていってよい)
  • mod. 1000000007 で  \frac{L}{B_{1}} + \frac{L}{B_{2}} + \dots + \frac{L}{B_{N}} を計算する

という風にして AC することができる。計算量は、 A_{i} の最大値を  M として  O(N\sqrt{M})

#include <iostream>
#include <vector>
#include <string>
using namespace std;
#define COUT(x) cout << #x << " = " << (x) << " (L" << __LINE__ << ")" << endl

// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
using mint = Fp<1000000007>;

vector<pair<long long, long long> > prime_factorize(long long n) {
    vector<pair<long long, long long> > res;
    for (long long p = 2; p * p <= n; ++p) {
        if (n % p != 0) continue;
        int num = 0;
        while (n % p == 0) { ++num; n /= p; }
        res.push_back(make_pair(p, num));
    }
    if (n != 1) res.push_back(make_pair(n, 1));
    return res;
}

int main() {
    int N; cin >> N;
    vector<long long> A(N);
    vector<long long> num(1100000, 0);
    for (int i = 0; i < N; ++i) {
        cin >> A[i];
        auto pf = prime_factorize(A[i]);
        for (auto p : pf) num[p.first] = max(num[p.first], p.second);
    }

    mint LCM = 1;
    for (int v = 2; v < 1100000; ++v) {
        LCM *= modpow(mint(v), num[v]);
    }

    mint res = 0;
    for (auto a : A) {
        res += LCM / a;
    }
    cout << res << endl;
}

2. 素因数分解を高速化する

今回は  A_{1}, A_{2}, \dots, A_{N} を愚直に素因数分解したが、ここを高速化することを考える。実は  M = 10^{6} くらいとして、  1, 2, 3, \dots, M をまとめて高速に素因数分解してしまう方法がある。大昔、osa_k 法とかよばれていた方法だ。エラトスネテスの篩が有効活用できる。

エラトステネスの篩は、 M 以下の素数を列挙する方法であるが、ついでに以下の配列を作成することができる:

  • min_factor[ n ] := n を割り切る最小の素数

ここまでの前処理を  O(N \log\log{N}) でできる。この前処理を行っておくと、 M 以下の整数  n素因数分解を、以下のように  O(\log{n}) でできる。

  • 「n を p = min_factor[ n ] で割れるだけ割る」というのを n = 1 となるまで繰り返す

たとえば n = 120 だったら、

  • min_factor[120] = 2 なので、2 で割れるだけ割ると 15 になる
  • min_factor[15] = 3 なので、3 で割れるだけ割ると 5 になる
  • min_factor[5] = 5 なので、5 で割れるだけ割ると 1 になる

という感じ。計算量は  O(M\log\log{M} + N \log{M})。なお、editorial には  O(M + N\log{M}) と書いてあるけど、これは、エラトステネスの篩の処理を線形時間で実行するマニアックなアルゴリズムを使った場合の話と思われる。

#include <iostream>
#include <vector>
using namespace std;


// modint: mod 計算を int を扱うように扱える構造体
template<int MOD> struct Fp {
    long long val;
    constexpr Fp(long long v = 0) noexcept : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr int getmod() { return MOD; }
    constexpr Fp operator - () const noexcept {
        return val ? MOD - val : 0;
    }
    constexpr Fp operator + (const Fp& r) const noexcept { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp& r) const noexcept { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp& r) const noexcept { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp& r) const noexcept { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp& r) noexcept {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp& r) noexcept {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp& r) noexcept {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp& r) noexcept {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr bool operator == (const Fp& r) const noexcept {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp& r) const noexcept {
        return this->val != r.val;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD>& x) noexcept {
        return os << x.val;
    }
    friend constexpr Fp<MOD> modpow(const Fp<MOD> &a, long long n) noexcept {
        if (n == 0) return 1;
        auto t = modpow(a, n / 2);
        t = t * t;
        if (n & 1) t = t * a;
        return t;
    }
};
using mint = Fp<1000000007>;
 

// エラトスネテスの篩
struct Eratos {
    vector<int> primes;
    vector<bool> isprime;
    vector<int> min_factor;

    Eratos(int MAX) : primes(),
                      isprime(MAX+1, true),
                      min_factor(MAX+1, -1) {
        isprime[0] = isprime[1] = false;
        min_factor[0] = 0, min_factor[1] = 1;
        for (int i = 2; i <= MAX; ++i) {
            if (!isprime[i]) continue;
            primes.push_back(i);
            min_factor[i] = i;
            for (int j = i*2; j <= MAX; j += i) {
                isprime[j] = false;
                if (min_factor[j] == -1) min_factor[j] = i;
            }
        }
    }

    // prime factorization
    vector<pair<int,int>> prime_factorize(int n) {
        vector<pair<int,int> > res;
        while (n != 1) {
            int prime = min_factor[n];
            int exp = 0;
            while (min_factor[n] == prime) {
                ++exp;
                n /= prime;
            }
            res.push_back(make_pair(prime, exp));
        }
        return res;
    }
};


int main() {
    int N; cin >> N;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    Eratos er(1100000);    
    vector<int> num(1100000, 0);
    for (int i = 0; i < N; ++i) {
        auto pf = er.prime_factorize(A[i]);
        for (auto p : pf) num[p.first] = max(num[p.first], p.second);
    }
 
    mint LCM = 1;
    for (int v = 2; v < 1100000; ++v) {
        LCM *= modpow(mint(v), num[v]);
    }
 
    mint res = 0;
    for (auto a : A) {
        res += LCM / a;
    }
    cout << res << endl;
}