けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

TopCoder SRM 735 DIV1 Medium - QuadraticIdentity

ちょうど中国剰余定理シリーズやっていたので、ピンポイントだった。

問題概要

整数  m が与えられる。

 x^{2} ≡ x (mod.  m)

を満たすような  x (0 \le x <  m) をすべて求め、そのサイズが 500 より大きい場合には 500 以下になるまで以下の操作をした上で出力せよ:

  • 数列を小さい順にソートして奇数番目のみを残す

制約

  •  1 \le m \le 10^{15}

解法

サイズが 500 より大きかったら...という話が怪しい雰囲気だけど、単に topcoder の出力受け取りキャパシティの問題っぽい???全然問題とは関係ない感じっぽい。

さて、

 x(x-1) ≡ 0 (mod.  m)

になる。 x(x-1) m の倍数というわけだが、x と x-1 とは互いに素なので、

 m = p_1^{k_1} p_2^{k_2} \dots p_n^{k_n}

と素因数分解したときに、 x x-1 とは共通の素因子をもたない。したがって、 n 個の素数べき  p_1^{k_1}, p_2^{k_2}, \dots, p_n^{k_n} x x-1 とに振り分ける感じになる。その方法は  2^{n} 通りある。振り分けて

  •  x ≡ 0 (mod.  A)
  •  x ≡ 1 (mod.  B)

になったとする。これを満たす  x (mod.  AB (=m)) は、中国剰余定理を用いて求めることができる。ライブラリで殴ってもいいのだが、もっと簡単に求めることもできる。

mod.  B での  A の逆元を  a として、

  •  x = aA

としてあげると、 x ≡ 0 (mod.  A) と、 x ≡ 1 (mod.  B) をともに満たす。

最後に素因数分解したときの素因子の種類数  n がどの程度の大きさになり得るかを見積もっておく。Python でサッと計算すると

>>> 2*3*5*7*11*13*17*19*23*29*31*37*41
304250263527210
>>> 2*3*5*7*11*13*17*19*23*29*31*37*41*43
13082761331670030

となったので、最悪でも  n \le 13 であることがわかる。したがって、求める  x の個数は最悪でも  2^{13} = 8192 個であり、大分余裕がある。

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

vector<long long> prime_factor(long long n) {
    vector<long long> res;
    for (long long i = 2; i*i <= n; ++i) {
        long long num = 1;
        if (n % i == 0) {
            while (n%i == 0) {
                num *= i;
                n /= i;
            }
            res.push_back(num);
        }
    }
    if (n != 1) res.push_back(n);
    return res;
}

inline long long mod(long long a, long long m) {
    return (a % m + m) % m;
}

inline long long inv(long long a, long long m) {
    long long b = m, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a -= t*b; swap(a, b);
        u -= t*v; swap(u, v);
    }
    return mod(u, m);
}

class QuadraticIdentity {
public:
    vector<long long> getFixedPoints(long long m) {
        vector<long long> pf = prime_factor(m);
        vector<long long> list;
        int n = (int)pf.size();
        for (int bit = 0; bit < (1<<n); ++bit) {
            long long A = 1, B = 1;
            for (int i = 0; i < pf.size(); ++i) {
                if (bit & (1<<i)) A *= pf[i];
                else B *= pf[i];
            }
            long long tmp = inv(A, B) * A;
            list.push_back(tmp);
        }
        sort(list.begin(), list.end());
        
        while (list.size() > 500) {
            vector<long long> nlist;
            for (int i = 0; i < (int)list.size(); i += 2) nlist.push_back(list[i]);
            list = nlist;
        }
        
        return list;
    }
};