けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

CS Academy 081 DIV2 C - All Numbers

解けたけどもっとちゃんと整理しないとなん

問題へのリンク

問題概要

N 個の整数 a_1, a_2, ..., a_N が与えられる (0 <= a_i <= 9)。これを並び替えて順につないでできる整数 (leading zero は除く) として考えられるものの総和を 109 + 7 で割った余りで求めよ。

制約

  • 1 <= N <= 50

解法

「各桁ごとに見る」という典型解法がピタリとはまる問題ではあるが、leading zero がちょっとイヤな感じ。

leading zero がなければ、対称性から各桁ごとに足し算した値が一致する。そしてその値は

  • 1 × (数列から 1 を除いたものの個数)
  • 2 × (数列から 2 を除いたものの個数)
  • 9 × (数列から 9 を除いたものの個数)

を合計したものとなる。

leading zero がある場合には、「最高位には 1〜9 しか来ない」という風に考えてあげればよい。

#include <iostream>
#include <vector>
using namespace std;

const int MAX = 210000;
const int MOD = 1000000007;

long long fac[MAX],finv[MAX],inv[MAX];
void COMinit(){
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for(int i = 2; i < MAX; i++){
        fac[i] = fac[i-1] * i % MOD;
        inv[i] = MOD - inv[MOD%i] * (MOD/i) % MOD;
        finv[i] = finv[i-1] * inv[i] % MOD;
    }
}

inline long long mod(long long a, long long m) { return (a % m + m) % m; }

long long modinv(long long a, long long m) {
    long long b = m, u = 1, v = 0;
    while (b) {
        long long t = a/b;
        a -= t*b; swap(a, b);
        u -= t*v; swap(u, v);
    }
    return mod(u, m);
}

// 0, 1, ..., 9 の個数が nums で与えられている場合の順列の個数
long long calc(vector<int> nums) {
    long long all = 0;
    long long res = 1;
    for (int i = 0; i < 10; ++i) {
        all += nums[i];
        res = res * finv[nums[i]] % MOD;
    }
    res = res * fac[all] % MOD;
    return res;
}

int main() {
    COMinit();
    int N; cin >> N;
    vector<int> nums(10, 0);
    for (int i = 0; i < N; ++i) {
      int v; cin >> v; nums[v]++;
    }

    // 最高位の総和
    long long final = 0;
    for (int d = 1; d < 10; ++d) {
      if (nums[d] == 0) continue;
      nums[d]--;
      long long p = calc(nums);
      nums[d]++;
      final = (final + p * d % MOD) % MOD;
    }

    // それ以外の総和
    long long other = 0;
    for (int d = 1; d < 10; ++d) {
      if (nums[d] == 0) continue;
      nums[d]--;
      long long p = 0;
      // 最高位を決め打ち
      for (int e = 1; e < 10; ++e) {
        if (nums[e] == 0) continue;
        nums[e]--;
        long long q = calc(nums);
        p = (p + q) % MOD;
        nums[e]++;
      }
      nums[d]++;
      other = (other + p * d % MOD) % MOD;
    }
    
    long long res = 0;
    for (int i = 0; i < N; ++i) {
      res = (res * 10) % MOD;
      if (i != 0) res = (res + other) % MOD;
      else res = (res + final) % MOD;
    }
    cout << res << endl;
}