けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

AtCoder ABC 176 C - Max Permutation (黄色, 700 点)

これは面白かった。

問題概要

 1, 2, \dots, N の順列であって、以下の  M 個の条件を満たすものの個数を 998244353 で割った余りを求めよ。

  •  j = 1, 2, \dots, M について、 \max(P_{A_{j}}, P_{B_{j}}) = C_{j} である

制約

  •  N, M \le 2 \times 10^{5}

考えたこと

グラフの問題として考えることにした。つまり、0-indexed で表現すると


頂点数   N、辺数  M であるグラフが与えられる。各辺には  0 以上  N-1 以下の重みがついている。

各頂点に  0, 1, \dots, N-1 の値を割り振る方法であって、どの辺についても両端の頂点の値の最大値が辺の重みに一致する方法の個数を求めよ。


重みが大きい方から順に考えることとした。次のように考えることができた。

また、常に「未使用の孤立頂点の個数」を表す値  s を管理しておくこととする。初期状態では

  • res = 1
  • s = (初期グラフの孤立点の個数)

としておく。


 x = N-1, N-2, \dots, 0 について順に考えて

  • 重み  x の辺が存在しないとき:
    • もし s = 0 ならば、答えは問答無用で 0 通りである
    • そうでないならば、res *= s, --s とする
  • 重み  x の辺がちょうど 1 本存在するとき:
    • その時点でのグラフについて、もしその辺の両端の次数がともに 2 以上ならば、答えは問答無用で 0 通りである
    • そうでないとき、重み  x の頂点を以下のように決定し、その頂点を削除する (このとき、孤立点が新たに誕生するならば ++s とする)
      • 片方の頂点の次数が 1、他方の頂点の次数が 2 以上のとき、次数 1 の頂点の値を  x にする
      • 両端の頂点の次数が 1 のとき、いずれかの頂点の値を  x とする (`res *= 2' とする)
  • 重み  v の辺が 2 本以上存在するとき:
    • それらの辺がスターグラフを形成していないならば、問答無用で 0 通りである
    • そうでないとき、スターグラフの中心の頂点を  v として、以下のようにする
      • 頂点  v に重み  v 未満の辺が接続しているならば、問答無用で 0 通りである
      • そうでないとき、頂点  v の重みを  x の頂点として、その頂点  v を削除する (このとき、孤立点が新たに誕生するならば ++s とする)

こうして重みが  v = N-1, N-2, \dots, 0 の頂点を順に決定していく。計算量は  O(N) となる。

コード

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;

// modint
template<int MOD> struct Fp {
    // inner value
    long long val;
    
    // constructor
    constexpr Fp() : val(0) { }
    constexpr Fp(long long v) : val(v % MOD) {
        if (val < 0) val += MOD;
    }
    constexpr long long get() const { return val; }
    constexpr int get_mod() const { return MOD; }
    
    // arithmetic operators
    constexpr Fp operator + () const { return Fp(*this); }
    constexpr Fp operator - () const { return Fp(0) - Fp(*this); }
    constexpr Fp operator + (const Fp &r) const { return Fp(*this) += r; }
    constexpr Fp operator - (const Fp &r) const { return Fp(*this) -= r; }
    constexpr Fp operator * (const Fp &r) const { return Fp(*this) *= r; }
    constexpr Fp operator / (const Fp &r) const { return Fp(*this) /= r; }
    constexpr Fp& operator += (const Fp &r) {
        val += r.val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -= (const Fp &r) {
        val -= r.val;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp& operator *= (const Fp &r) {
        val = val * r.val % MOD;
        return *this;
    }
    constexpr Fp& operator /= (const Fp &r) {
        long long a = r.val, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b, swap(a, b);
            u -= t * v, swap(u, v);
        }
        val = val * u % MOD;
        if (val < 0) val += MOD;
        return *this;
    }
    constexpr Fp pow(long long n) const {
        Fp res(1), mul(*this);
        while (n > 0) {
            if (n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }
    constexpr Fp inv() const {
        Fp res(1), div(*this);
        return res / div;
    }

    // other operators
    constexpr bool operator == (const Fp &r) const {
        return this->val == r.val;
    }
    constexpr bool operator != (const Fp &r) const {
        return this->val != r.val;
    }
    constexpr Fp& operator ++ () {
        ++val;
        if (val >= MOD) val -= MOD;
        return *this;
    }
    constexpr Fp& operator -- () {
        if (val == 0) val += MOD;
        --val;
        return *this;
    }
    constexpr Fp operator ++ (int) const {
        Fp res = *this;
        ++*this;
        return res;
    }
    constexpr Fp operator -- (int) const {
        Fp res = *this;
        --*this;
        return res;
    }
    friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) {
        is >> x.val;
        x.val %= MOD;
        if (x.val < 0) x.val += MOD;
        return is;
    }
    friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) {
        return os << x.val;
    }
    friend constexpr Fp<MOD> pow(const Fp<MOD> &r, long long n) {
        return r.pow(n);
    }
    friend constexpr Fp<MOD> inv(const Fp<MOD> &r) {
        return r.inv();
    }
};

const int MOD = 998244353;
using mint = Fp<MOD>;

int main() {
    int N, M;
    cin >> N >> M;
    
    // グラフを求める
    vector<vector<pint>> G(N);
    vector<int> deg(N, 0);
    vector<vector<pint>> edges(N);
    for (int i = 0; i < M; ++i) {
        int a, b, x;
        cin >> a >> b >> x;
        --a, --b, --x;
        G[a].emplace_back(b, x), G[b].emplace_back(a, x);
        ++deg[a], ++deg[b];
        edges[x].emplace_back(a, b);
    }
    
    // 孤立点の個数を求める
    int solitude = 0;
    for (int i = 0; i < N; ++i) {
        if (deg[i] == 0) ++solitude;
    }
    
    // 頂点 v を削除する関数
    auto pop = [&](int v) -> void {
        deg[v] = 0;
        for (auto e : G[v]) {
            --deg[e.first];
            if (deg[e.first] == 0) {
                ++solitude;
            }
        }
    };
    
    auto solve = [&]() -> mint {
        mint res = 1;
        
        for (int x = N-1; x >= 0; --x) {
            if (edges[x].empty()) {
                if (solitude == 0) return mint(0);
                res *= solitude;
                --solitude;
            } else {
                const auto &vec = edges[x];
                if (vec.size() == 1) {
                    int a = vec[0].first, b = vec[0].second;
                    if (deg[a] == 1 && deg[b] == 1) {
                        res *= 2;
                        pop(a);
                    } else if (deg[a] == 1) {
                        pop(a);
                    } else if (deg[b] == 1) {
                        pop(b);
                    } else {
                        return mint(0);
                    }
                } else {
                    map<int, int> ma;
                    for (auto [a, b] : vec) {
                        ++ma[a];
                        ++ma[b];
                    }
                    
                    int node = -1;
                    for (auto [val, num] : ma) {
                        if (num != 1) {
                            if (node != -1) return mint(0);
                            node = val;
                        }
                    }
                    
                    if (deg[node] != vec.size()) return mint(0);
                    pop(node);
                }
            }
        }
        return res;
    };
    cout << solve() << endl;
}