けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 361 F - x = a^b (1D, 青色, 500 点)

包除原理した!

問題概要

1 以上  N 以下の正整数  x であって、ある正整数  a と 2 以上の正整数  b を用いて  x = a^{b} と表現できるものはいくつありますか?

制約

  •  1 \le N \le 10^{18}

考えたこと

まず考えたのは、 b は素数のみ考えれば良いということだった。たとえば、

 2^{6} = 4^{3} = 8^{2}

というように、 b が合成数の場合、 a を取り直すことで  b が素数である場合に式変形できるのだ。また、 2^{61} \ge 10^{18} であることから、 b \le 60 の場合のみ考えればよい。

よって、考えるべき  b

 b = 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59

の 17 個に絞られることとなる。

包除原理

上記の 17 個の素数  b を考えれば良いのだが、異なる  b が同一の数を導くことがある。そこで包除原理することにする。

たとえば、 x^{2} の形でも書けて  y^{3} の形でも書けるような数は  z^{6} の形で書けることに注目する。

より一般に、 i = 1, 2, \dots, K に対して  x_{i}^{b_{i}} の形に書けるような数は、ある整数  X に対して  X^{b_{1}b_{2}\dots b_{K}} と書ける。

このことを用いて、上記 17 個の素数の  2^{17} 通りの部分集合に対して、その積を  B として、 N 以下の整数のうち  X^{B} の形で書ける数の個数を求めれば良いことがわかった。

コード

整数  N K 乗根を求める処理は、意外と誤差が怖い。ライブラリを作っておくとよくて、ライブラリの verify はこの問題でできる。

#include <bits/stdc++.h>
using namespace std;

// N < 2^64, K <= 64
uint64_t kth_root(uint64_t N, uint64_t K) {
    assert(K >= 1);
    if (N <= 1 || K == 1) return N;
    if (K >= 64) return 1;
    if (N == uint64_t(-1)) --N;
    
    auto mul = [&](uint64_t x, uint64_t y) -> uint64_t {
        if (x < UINT_MAX && y < UINT_MAX) return x * y;
        if (x == uint64_t(-1) || y == uint64_t(-1)) return uint64_t(-1);
        return (x <= uint64_t(-1) / y ? x * y : uint64_t(-1));
    };
    auto power = [&](uint64_t x, uint64_t k) -> uint64_t {
        if (k == 0) return 1ULL;
        uint64_t res = 1ULL;
        while (k) {
            if (k & 1) res = mul(res, x);
            x = mul(x, x);
            k >>= 1;
        }
        return res;
    };
    
    uint64_t res;
    if (K == 2) res = sqrtl(N) - 1;
    else if (K == 3) res = cbrt(N) - 1;
    else res = pow(N, nextafter(1 / double(K), 0));
    while (power(res + 1, K) <= N) ++res;
    return res;
}

int main() {
    vector<long long> prs = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59};
    long long N;
    cin >> N;
    
    long long res = 0;
    for (int bit = 1; bit < (1 << prs.size()); ++bit) {
        long long K = 1;
        for (int i = 0; i < prs.size(); ++i) {
            if (bit & (1 << i)) {
                K *= prs[i];
                if (K > 60) break;
            }
        }
        if (__builtin_popcount(bit) % 2 == 1) res += kth_root(N, K);
        else res -= kth_root(N, K);
    }
    cout << res << endl;
}