けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 246 F - typewriter (青色, 500 点)

包除原理を学べる問題!

問題概要

 N 個の文字列  S_{1}, S_{2}, \dots, S_{N} が与えられます。次の手順によって作れる長さ  L の文字列の個数を 998244353 で割ったあまりを求めてください。

  1.  k = 1, 2, \dots, N のいずれかを選ぶ
  2. 文字列  S_{k} に含まれる文字のみを使って、長さ  L の文字列を作る

制約

  •  1 \le N \le 18
  •  1 \le L \le 10^{9}

包除原理へ

たとえば  N = 2 L = 2 で文字列が "ab"、"ac" である場合は

  • "ab" から作れるのは、"aa", "ab", "ba", "bb" の 4 通り
  • "ac" から作れるのは、"aa", "ac", "ca", "cc" の 4 通り
  • 両方から作れるのは、"aa" の 1 通り

ということで、4 + 4 - 1 = 7 通りと求められます。より一般に  N = 2 の場合は、

(S[0] から作れる文字列の個数)
+ (S[1] から作れる文字列の個数)
- (S[0], S[1] の両方から作れる文字列の個数)

を計算すればよいことになります。

なお、S[0] から作れる文字列の個数は、S[0] に含まれる文字の種類数を  m として  m^{L} 通りとなります。

S[0], S[1] から作れる文字列の個数は、S[0] と S[1] にともに含まれる文字の種類数を  m として  m^{L} 通りとなります。

一般の  N の場合

 N = 2 の場合を拡張して、一般の  N の場合は次のように求められます。


文字列  S_{1}, S_{2}, \dots, S_{N} の空でない部分集合  T 2^{N} - 1 通り考えられる。そのそれぞれに対して、

  • 集合  T に含まれる文字列のいずれからも、作れる長さ  L の文字列の個数を  f(T) としたとき
    •  T の要素数が奇数ならば  f(T) を加算
    •  T の要素数が偶数ならば  f(T) を減算

なお、 f(T) の計算は、 T に含まれる文字列のすべてに含まれる文字の個数を  m として、

 f(T) = m^{L}

と計算できます。

コード

計算量は、各部分集合  T に対して

  • すべてに含まれる文字の個数  m の計算: O(N)
  •  m^{L} の計算: O(\log L)

ですので、全体として  O(2^{N}(N + \log L)) となります。

#include <bits/stdc++.h>
using namespace std;

const int MOD = 998244353;
long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}

int main() {
    int N, L;
    cin >> N >> L;
    vector<int> S(N, 0);
    for (int i = 0; i < N; ++i) {
        string str;
        cin >> str;
        for (char c : str) S[i] |= 1 << (c - 'a');
    }

    long long res = 0;
    for (int bit = 1; bit < (1 << N); ++bit) {
        int tmp = (1 << 26) - 1;
        for (int i = 0; i < N; ++i) {
            if (bit & (1 << i)) tmp &= S[i];
        }
        int m = __builtin_popcount(tmp);
        long long mL = modpow(m, L, MOD);

        if (__builtin_popcount(bit) & 1) res = (res + mL) % MOD;
        else res = (res + MOD - mL) % MOD;
    }
    cout << res << endl;
}