けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 307 E - Distinct Adjacent (1Q, 水色, 475 点)

共通テスト数学 IA にも似た問題が出ていた!

問題概要

頂点数が  N のサイクルグラフが与えられる。このグラフの各頂点を色  1, 2, \dots, M のいずれかの色で塗る。

どの隣接する頂点対も異なる色で塗られるようにする方法の個数を 998244353 で割った余りを求めよ。

制約

  •  2 \le N, M \le 10^{6}

解法 (1):普通の DP

まずは円環上で DP するときの一般的なテクを活用する。最初に頂点 1 を色で塗ることにしよう。この方法は  M 通りある。その後、頂点 [2, 3, \dots, N] を順に塗っていく。このとき、次の DP をしよう。


  • dp[n][f] ← 頂点  1, 2, \dots, n の色を塗る方法のうち、 i = 1, 2, \dots, n-1 に対して頂点  i, i+1 の色が異なり、さらに、
    •  f = 0 のときは頂点  n の色が頂点  1 と異なり、
    •  f = 1 のときは頂点 1 と同じであるような方法の個数

dp[n] から dp[n+1] への具体的な遷移は次のように考えられる。

  • dp[n][0] から dp[n+1][0] への遷移:
    • この場合は頂点  n, n+1 が同じ色で塗られることになり不適。よって 0 通り
  • dp[n][0] から dp[n+1][1] への遷移:
    • この場合は頂点  n+1 の色を、頂点  0, n とは異なる色にすればよいので  M-1 通りある
    • よって、dp[n+1][1] += dp[n][0] * (M-1) と書ける
  • dp[n][1] から dp[n+1][0] への遷移:
    • この場合は頂点  n+1 を頂点 0 と同じ色に塗るしかないので 1 通り
    • よって、dp[n+1][0] += dp[n][1] と書ける
  • dp[n][1] から dp[n+1][1] への遷移:
    • この場合は頂点  n+1 の色を、頂点  0 とも頂点  n とも異なる色にするので  M-2 通りある
    • よって、dp[n+1][1] += dp[n][1] * (M-2) と書ける

最終的に dp[N][0] を答えれば良い。計算量は  O(N) と評価できる。

コード

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() : val(0) { }
    constexpr Fp(long long v) : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const { return val; }
    constexpr int get_mod() const { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator + () const { return Fp(*this); }
    constexpr Fp operator - () const { return Fp(0) - Fp(*this); }
    constexpr Fp operator + (const Fp &r) const { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const {
        return this->val != r.val;
    }
    constexpr Fp& operator ++ () {
        ++val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -- () {
        if (val == 0) val += MOD;
        --val;
        return *this;
    }
    constexpr Fp operator ++ (int) const {
        Fp res = *this;
        ++*this;
        return res;
    }
    constexpr Fp operator -- (int) const {
        Fp res = *this;
        --*this;
        return res;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) {
        return os << x.val;
    }
    friend constexpr Fp<MOD> pow(const Fp<MOD> &r, long long n) {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> inv(const Fp<MOD> &r) {
        return r.inv();
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N, M;
    cin >> N >> M;
    
    vector dp(N+1, vector(2, mint(0)));
    dp[1][0] = M;
    for (int n = 1; n < N; ++n) {
        dp[n+1][1] += dp[n][0] * (M-1);
        dp[n+1][0] += dp[n][1];
        dp[n+1][1] += dp[n][1] * (M-2);
    }
    cout << dp[N][1] << endl;
}

 

解法 (2):包除原理

包除原理でも解ける。まず、

  • 頂点 1 をある色で塗り ( M 通り)
  • 頂点 2 を頂点 1 と異なる色で塗り ( M-1 通り)
  • 頂点 3 を頂点 2 と異なる色で塗り ( M-1 通り)
  • ...
  • 頂点  N を頂点  N-1 と異なる色で塗り ( M-1 通り)

と繰り返していく方法は  M(M-1)^{N-1} 通りある。

このとき、ほとんどの隣接頂点対は異なる色で塗られるが、唯一、頂点 1 と頂点  N だけは同じ色で塗られる可能性がある。この場合を除外することにしよう。一瞬難しく思えるが、実は難しくない。頂点 1 と頂点  N をひとかたまりにすると、実は「 N-1 個の頂点からなるサイクルグラフを条件が満たすように  M 色に塗り分ける」方法の個数に等しいのだ。

よって、 N 頂点の問題が、 N-1 頂点の問題へ帰着されたのだ。こうなると DP したくなるだろう。


dp[n] ← 頂点数  n のサイクルグラフの各頂点を  M 色で塗り分ける方法であって、どの隣接する頂点も異なる色で塗られる方法の個数


上記の議論によって、次のように更新式が立てられる。

dp[n] =  M(M-1)^{n-1} \times dp[n-1]

この式を用いると計算量  O(N) で解ける。

コード

ここでは、 M(M-1)^{n-1} を計算する部分を愚直に関数 pow() を使った。

#include <bits/stdc++.h>
using namespace std;

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() : val(0) { }
    constexpr Fp(long long v) : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const { return val; }
    constexpr int get_mod() const { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator + () const { return Fp(*this); }
    constexpr Fp operator - () const { return Fp(0) - Fp(*this); }
    constexpr Fp operator + (const Fp &r) const { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const {
        return this->val != r.val;
    }
    constexpr Fp& operator ++ () {
        ++val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -- () {
        if (val == 0) val += MOD;
        --val;
        return *this;
    }
    constexpr Fp operator ++ (int) const {
        Fp res = *this;
        ++*this;
        return res;
    }
    constexpr Fp operator -- (int) const {
        Fp res = *this;
        --*this;
        return res;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) {
        return os << x.val;
    }
    friend constexpr Fp<MOD> pow(const Fp<MOD> &r, long long n) {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> inv(const Fp<MOD> &r) {
        return r.inv();
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N, M;
    cin >> N >> M;
    
    vector<mint> dp(N + 1, 0);
    dp[2] = mint(M) * (M - 1);
    for (int n = 3; n <= N; ++n) {
        dp[n] = mint(M) * mint(M-1).pow(n-1) - dp[n-1];
    }
    cout << dp[N] << endl;
}