けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Codeforces Round #615 (Div. 3) E. Obtain a Permutation (R1900)

こういうのをちゃんと解けるようになったのは成長!

問題へのリンク

問題概要 (意訳)

以下の  M 個のクエリに答えよ。

 i 番目のクエリでは、数列  a_{1}, a_{2}, \dots, a_{N} が与えられる。この数列に以下のいずれかの操作をほどこして、 i, i+N, i+2N, \dots となるようにしたい。その最小回数を求めよ。

  •  a_{i} を任意の値に書き換える
  • 数列を circular shift する。具体的には  a_{2}, a_{3}, \dots, a_{N}, a_{1} にする。

制約

  •  1 \le M, N \le 2 \times 10^{5}
  •  1 \le N \times M \le 2 \times 10^{5}

考えたこと

 N \times M N 10^{5} 程度の大きさなので、各クエリには  O(N) O(N\log{N}) で答える必要がある。 i, i+N, i+2N, \dots b_{1}, b_{2}, \dots, b_{N} とおく。

まず言えることは、操作を行う流れとしては

  • 先に巡回操作を何回か行ってから
  • 数列の各項を書き換えていく

という風にしてよいことがわかる。そこで次のような方針が立つ。

  • 数列の巡回操作を行う回数  k を固定したときに、それによって得られる数列  a と数列  b のハミング距離を求めればよい
  • そのハミング距離 +  k の値の最小値を求めればよい

ここで逆転の発想をする。 a の各要素に対して、「巡回操作の回数が何回であれば、その値を書き換えなくて済むか」を求めていくことにする。たとえば

  • a = (3, 1, 4, 3, 5, 5)
  • b = (1, 2, 3, 4, 5, 6)

のときは、

  • a[0]: 4 回
  • a[1]: 1 回
  • a[2]: 5 回
  • a[3]: 1 回
  • a[4]: 0 回
  • a[5]: 1 回

となる。これによって、巡回操作が 0, 1, 2, 3, 4, 5 回であるようなものの登場個数がそれぞれ、1 個、3 個、0 個、0 個、1 個、1 個となっているので、総合的な操作回数は

  • 巡回操作が 0 回のとき、0 + (6 - 1) = 5 回
  • 巡回操作が 1 回のとき、1 + (6 - 3) = 4 回
  • 巡回操作が 2 回のとき、2 + (6 - 0) = 8 回
  • 巡回操作が 3 回のとき、3 + (6 - 0) = 9 回
  • 巡回操作が 4 回のとき、4 + (6 - 1) = 9 回
  • 巡回操作が 5 回のとき、5 + (6 - 1) = 10 回

ということがわかる。このうちの最小値を求めればよい。

#include <iostream>
#include <vector>
#include <map>
using namespace std;

int calc(const vector<int> &a, map<int, int> &target) {
    int N = a.size();
    vector<int> con(N, 0);
    for (int i = 0; i < N; ++i) {
        if (target.count(a[i])) {
            int need = (N + i - target[a[i]]) % N;
            con[need]++;
        }
    }
    int res = N;
    for (int i = 0; i < N; ++i) {
        int tmp = i + (N - con[i]);
        res = min(res, tmp);
    }
    return res;
}
 
int main() {
    int N, M;
    scanf("%d %d", &N, &M);
    vector<vector<int> > a(M, vector<int>(N));
    for (int i = 0; i < N; ++i) for (int j = 0; j < M; ++j) scanf("%d", &a[j][i]);

    long long res = 0;
    for (int j = 0; j < M; ++j) {
        map<int, int> ma;
        for (int i = 0; i < N; ++i) ma[i*M + j + 1] = i;
        long long tmp = calc(a[j], ma);
        res += tmp;
    }
    printf("%lld\n", res);
}