けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

Educational Codeforces 62 E - Palindrome-less Arrays

楽しかった

問題へのリンク

問題概要

未完成の数列  a_1, a_2, \dots, a_N が与えられる。未完成部分には  -1 が書かれている。完成部分には  1 以上  K 以下の整数が書かれている。

 -1 のところを  1 以上  K 以下の整数を埋める方法であって、整数列が奇数長の回文を含まないものは何通りあるか?

制約

  •  2 \le N, K \le 2 × 10^{5}

考えたこと

奇数長の回文を含まないことと、任意の  i に対して  a_{i} \neq a_{i+2} が成り立つことは同値である。よって、数列の偶数番目と奇数番目それぞれについて

  • 隣接要素が異なるように -1 を埋める方法

を数え上げて掛け算すればよい。この方法は

  • 両端が異なる数字で -1 が len 個連続する区間
  • 両端が同じ数字で -1 が len 個連続する区間
  • 片方が開いていて -1 が len 個連続する区間
  • 両端が開いていて -1 が len 個連続する区間

とをそれぞれ DP で前処理する方針で解いた。

#include <iostream>
#include <vector>
using namespace std;

const long long MOD = 998244353;

int main() {
    int N, K; cin >> N >> K;
    vector<long long> a(N);
    for (int i = 0; i < N; ++i) cin >> a[i];
    
    vector<vector<long long> > dp(N+1, vector<long long>(4, 0));
    dp[0][0] = 1;
    dp[0][2] = 1;
    dp[0][3] = 1;
    for (int len = 1; len <= N; ++len) {
        dp[len][0] = (dp[len-1][1] + dp[len-1][0] * (K-2) % MOD) % MOD;
        dp[len][1] = dp[len-1][0] * (K-1) % MOD;
        dp[len][2] = (dp[len-1][2] * (K-1)) % MOD;
        dp[len][3] = (dp[len-1][2] * K) % MOD;
    }
    long long res = 1;
    
    // 偶数
    int prev = -1;
    int len = 0;
    for (int i = 0; i < N; i += 2) {
        if (a[i] == -1) ++len;
        else {
            if (prev == -1) res = (res * dp[len][2]) % MOD;
                else if (a[i] == prev) res = (res * dp[len][1]) % MOD;
                else res = (res * dp[len][0]) % MOD;
            len = 0;
            prev = a[i];
        }
        }
    if (len) {
        if (prev == -1) res = (res * dp[len][3]) % MOD;
        else res = (res * dp[len][2]) % MOD;
    }
    
    // 奇数
    prev = -1;
    len = 0;
    for (int i = 1; i < N; i += 2) {
        if (a[i] == -1) ++len;
        else {
            if (prev == -1) res = (res * dp[len][2]) % MOD;
            else if (a[i] == prev) res = (res * dp[len][1]) % MOD;
                else res = (res * dp[len][0]) % MOD;
            len = 0;
            prev = a[i];
        }
    }
    if (len) {
        if (prev == -1) res = (res * dp[len][3]) % MOD;
            else res = (res * dp[len][2]) % MOD;
    }
    
    cout << res << endl;   
}