けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AISing Programming Contest 2020 D - Anything Goes to Zero (水色, 400 点)

結構難しい!!

問題概要

正の整数  n に対して、

  •  p(n) :=  n を二進法表現したときの各桁の総和を  s として  n s で割ったあまり
  •  f(n) :=  n f(n) で置き換える操作を繰り返したときに、何回で 0 になるか

として定める。たとえば  n = 7 のとき、 p(7) = 1,  p(1) = 0 より、 f(n) = 2 となる。

今、二進法表記で  N 桁の整数  S が与えられる。各  i = 0, 1, \dots, N-1 に対して、 S の上から  i 桁目をビット反転した値  S_{i} に対する  f(S_{i}) の値を求めよ。

制約

  •  1 \le N \le 2 \times 10^{5}

考えたこと

とりあえず、とりあえず  S に対して一回  p(S) をとると  N 以下の整数にはなる。よって、 f(n) の値は小さそうだという予想はつく。1 つの整数に対して求めるだけなら愚直にシミュレーションしてもよさそう。

また、 S を各桁ごとにビット反転することにる変化は

  •  p(S) の値は ±1 しか変化しない
  •  S 自体の値も  2^{i} しか変化しない

ということがわかる。よって、この差分は高速に計算できる。

#include <bits/stdc++.h>
using namespace std;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

long long p(long long n) {
    long long nn = n, con = 0;
    while (nn) {
        con += nn % 2;
        nn /= 2;
    }
    return n % con;
}

long long f(long long n) {
    long long res = 0;
    while (n) {
        ++res;
        n = p(n);
    }
    return res;
}

int main() {
    int N;
    string S;
    cin >> N >> S;

    long long defa = 0;
    for (auto c : S) defa += c-'0';
    vector<vector<long long>> power(3, vector<long long>(N, 1));
    for (int iter = 0; iter < 3; ++iter) {
        long long base = defa + iter - 1;
        if (base <= 0) continue;
        for (int i = 0; i < N-1; ++i) {
            power[iter][i+1] = power[iter][i] * 2 % base;
        }
    }
    vector<long long> origin(3, 0);
    for (int iter = 0; iter < 3; ++iter) {
        long long base = defa + iter - 1;
        if (base <= 0) continue;
        for (int i = 0; i < N; ++i) {
            if (S[i] == '1') {
                origin[iter] = (origin[iter] + power[iter][N-i-1]) % base;
            }
        }
    }

    for (int i = 0; i < N; ++i) {
        long long base = defa, iter = 1;
        if (S[i] == '0') ++base, ++iter;
        else --base, --iter;
        if (base <= 0) {
            cout << 0 << endl;
            continue;
        }
        long long one = origin[iter];
        if (S[i] == '0') one = (one + power[iter][N-i-1]) % base;
        else one = (one - power[iter][N-i-1] + base) % base;
        cout << f(one)+1 << endl;
    }
}