けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 150 F - Xor Shift (黄色, 600 点)

バチャやった。13 位相当で割とよかった。
ロリハした。

問題へのリンク

問題概要

長さ  N の整数列  a_{0}, a_{1}, \dots, a_{N-1},  b_{0}, b_{1}, \dots, b_{N-1} が与えられる。以下の条件を満たす  0 以上  N-1 以下の整数  k と、整数  x の組をすべて求めよ。

  •  a_{(i + k) \% N} =  b_{i} が成立する

制約

  •  1 \le N \le  2 \times 10^{5}

考えたこと

整数列  a を circular shift した上で、 x と XOR をとると  b に一致させるものをすべて求める問題。まず、 a の circular shift の大きさ  k を決めると

  •  x = a_{k} {\rm XOR} b_{0}

にしかなりえないことがわかる。これが条件を満たさないならダメ。そして

  • 2 つの整数  a,  b の XOR 差分  a {\rm XOR} b
  • それらの  xかました上での XOR 差分  (a {\rm XOR} x) {\rm XOR} (b {\rm XOR} x) = a {\rm XOR} b

は等しいことに着目する。そうしたら、 k を決めたときに条件を満たす  x が存在するためには、 a を circular shift したとき、 a の階差数列と、 b の階差数列とが等しいことが必要かつ十分であることがわかる。よって

  •  a'_{i} = a_{(i+1) \% N} {\rm XOR} a_{i}
  •  b'_{i} = b_{(i+1) \% N} {\rm XOR} b_{i}

として、改めて  a' の circular shift が  b' に一致する部分を全列挙する問題ということになる。

文字列問題へ

 a',  b' を改めて  a,  b とおく。 a を二巡させて長さ  2N の数列 (文字列) と思うことにする。このとき

  •  a \lbrack k : k + N \rbrack = b \lbrack 0 : N \rbrack

となるような  k を列挙すればよいことになる。これはローリングハッシュ ( O(N\log{N})) や KMP 法 ( O(N)) や Z 法 ( O(N)) でできる。 a \lbrack k : \rbrack b との LCP が  N 以上であるかどうかを判定すれば OK。

ロリハによる実装

 #include <iostream>
#include <vector>
using namespace std;

struct RollingHash {
    static const int base1 = 1007, base2 = 2009;
    static const int mod1 = 1000000007, mod2 = 1000000009;
    vector<long long> hash1, hash2, power1, power2;

    // construct
    RollingHash(const vector<long long> &S) {
        int n = (int)S.size();
        hash1.assign(n+1, 0);
        hash2.assign(n+1, 0);
        power1.assign(n+1, 1);
        power2.assign(n+1, 1);
        for (int i = 0; i < n; ++i) {
            hash1[i+1] = (hash1[i] * base1 + S[i]) % mod1;
            hash2[i+1] = (hash2[i] * base2 + S[i]) % mod2;
            power1[i+1] = (power1[i] * base1) % mod1;
            power2[i+1] = (power2[i] * base2) % mod2;
        }
    }
    
    // get hash of S[left:right]
    inline pair<long long, long long> get(int l, int r) const {
        long long res1 = hash1[r] - hash1[l] * power1[r-l] % mod1;
        if (res1 < 0) res1 += mod1;
        long long res2 = hash2[r] - hash2[l] * power2[r-l] % mod2;
        if (res2 < 0) res2 += mod2;
        return {res1, res2};
    }

    // get lcp of S[a:] and S[b:]
    inline int getLCP(int a, int b) const {
        int len = min((int)hash1.size()-a, (int)hash1.size()-b);
        int low = 0, high = len;
        while (high - low > 1) {
            int mid = (low + high) >> 1;
            if (get(a, a+mid) != get(b, b+mid)) high = mid;
            else low = mid;
        }
        return low;
    }

    // get lcp of S[a:] and T[b:]
    inline int getLCP(const RollingHash &T, int a, int b) const {
        int len = min((int)hash1.size()-a, (int)hash1.size()-b);
        int low = 0, high = len;
        while (high - low > 1) {
            int mid = (low + high) >> 1;
            if (get(a, a+mid) != T.get(b, b+mid)) high = mid;
            else low = mid;
        }
        return low;
    }
};


int N;
vector<long long> a, b;

void solve() {
    vector<long long> da(N*2), db(N*2);
    for (int i = 0; i < N*2; ++i) {
        int j = i % N;
        int k = (i + 1) % N;
        da[i] = a[j] ^ a[k];
        db[i] = b[j] ^ b[k];
    }

    RollingHash A(da), B(db);
    for (int k = 0; k < N; ++k) {
        int len = A.getLCP(B, k, 0);
        if (len >= N) {
            cout << k << " " << (a[k] ^ b[0]) << endl;
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    
    while (cin >> N) {
        a.resize(N); b.resize(N);
        for (int i = 0; i < N; ++i) cin >> a[i];
        for (int i = 0; i < N; ++i) cin >> b[i];
        solve();
    }
}

KMP 法による実装

#include <iostream>
#include <vector>
using namespace std;

struct KMP {
    vector<long long> pat;
    vector<int> fail;

    // construct
    KMP(const vector<long long> &p) { init(p); }
    void init(const vector<long long> &p) {
        pat = p;
        int m = (int)pat.size();
        fail.assign(m+1, -1);
        for (int i = 0, j = -1; i < m; ++i) {
            while (j >= 0 && pat[i] != pat[j]) j = fail[j];
            fail[i+1] = ++j;
        }
    }

    // the period of S[0:i]
    int period(int i) { return i - fail[i]; }
    
    // the index i such that S[i:] has the exact prefix p
    vector<int> match(const vector<long long> &S) {
        int n = (int)S.size(), m = (int)pat.size();
        vector<int> res;
        for (int i = 0, k = 0; i < n; ++i) {
            while (k >= 0 && S[i] != pat[k]) k = fail[k];
            ++k;
            if (k >= m) res.push_back(i - m + 1), k = fail[k];
        }
        return res;
    }
};

int N;
vector<long long> a, b;

void solve() {
    vector<long long> da(N*2), db(N);
    for (int i = 0; i < N*2; ++i) {
        int j = i % N;
        int k = (i + 1) % N;
        da[i] = a[j] ^ a[k];
        if (i < N) db[i] = b[j] ^ b[k];
    }
    KMP kmp(db);
    auto v = kmp.match(da);
    for (auto k : v) {
        if (k < N) cout << k << " " << (a[k] ^ b[0]) << endl;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    
    while (cin >> N) {
        a.resize(N); b.resize(N);
        for (int i = 0; i < N; ++i) cin >> a[i];
        for (int i = 0; i < N; ++i) cin >> b[i];
        solve();
    }
}

Z 法による実装

Z-algorithm を使うときは、上で求めた  b', a' をこの順に連結してやると良さそう。

#include <iostream>
#include <vector>
using namespace std;

vector<int> Zalgo(const vector<long long> &S) {
    int N = (int)S.size();
    vector<int> res(N);
    res[0] = N;
    int i = 1, j = 0;
    while (i < N) {
        while (i+j < N && S[j] == S[i+j]) ++j;
        res[i] = j;
        if (j == 0) {++i; continue;}
        int k = 1;
        while (i+k < N && k+res[k] < j) res[i+k] = res[k], ++k;
        i += k, j -= k;
    }
    return res;
}

int N;
vector<long long> a, b;

void solve() {
    vector<long long> ab(N * 3);
    for (int i = 0; i < N*3; ++i) {
        int j = i % N;
        int k = (i + 1) % N;
        if (i < N) ab[i] = b[j] ^ b[k];
        else ab[i] = a[j] ^ a[k];
    }

    auto v = Zalgo(ab);
    for (int k = N; k < N*2; ++k) {
        if (v[k] >= N) cout << k-N << " " << (a[k-N] ^ b[0]) << endl;
    }
}

int main() {
    while (cin >> N) {
        a.resize(N); b.resize(N);
        for (int i = 0; i < N; ++i) cin >> a[i];
        for (int i = 0; i < N; ++i) cin >> b[i];
        solve();
    }
}