行列累乗した。デバッグに手こずった。
問題概要
の順列 が与えられる。以下の操作を 回行う。
- を選んで と を swap する
操作列は 通り考えられるが、それぞれについての の総和を 998244353 で割った余りを求めよ。
制約
考えたこと
の期待値を求めて、最後に をかけることにした。
このとき、期待値の線形性から、操作後の の値の期待値を求めて、各 について合算すれば十分である。ここで、値の組 が操作によってどのように変化していくかを考えると、以下の 4 パターンを考えれば十分であることに気づく。なお、初期状態で , であるとする。
- が である状態 (順序は問わない)
- が と 以外である状態 (順序は問わない)
- が と 以外である状態 (順序は問わない)
- がいずれも でも でもない状態 (順序は問わない)
各グループについて 通りあるわけだが、同じグループ内では確率は等しいため、この 4 通りの状態に潰してよいわけだ。
この 4 通りの状態を移り変わる推移図を求めて、行列累乗することによって、 回操作後の、上記 4 状態になっている確率を求めることができる。これを用いて、操作後の の値の期待値が求められる。
なお、この求めた確率は の値によらないので、他の にも再利用できる。
全体として計算量は となる。
コード
#include <bits/stdc++.h> using namespace std; // matrix template<class mint> struct MintMatrix { // inner value vector<vector<mint>> val; // constructors MintMatrix(int H, int W, mint x = 0) : val(H, vector<mint>(W, x)) {} MintMatrix(const MintMatrix &mat) : val(mat.val) {} void init(int H, int W, mint x = 0) { val.assign(H, vector<mint>(W, x)); } void resize(int H, int W) { val.resize(H); for (int i = 0; i < H; ++i) val[i].resize(W); } // getter and debugger constexpr int height() const { return (int)val.size(); } constexpr int width() const { return (int)val[0].size(); } vector<mint>& operator [] (int i) { return val[i]; } constexpr vector<mint>& operator [] (int i) const { return val[i]; } friend constexpr ostream& operator << (ostream &os, const MintMatrix<mint> &mat) { os << endl; for (int i = 0; i < mat.height(); ++i) { for (int j = 0; j < mat.width(); ++j) { if (j) os << ", "; os << mat.val[i][j]; } os << endl; } return os; } // comparison operators constexpr bool operator == (const MintMatrix &r) const { return this->val == r.val; } constexpr bool operator != (const MintMatrix &r) const { return this->val != r.val; } // arithmetic operators constexpr MintMatrix& operator += (const MintMatrix &r) { assert(height() == r.height()); assert(width() == r.width()); for (int i = 0; i < height(); ++i) { for (int j = 0; j < width(); ++j) { val[i][j] += r.val[i][j]; } } return *this; } constexpr MintMatrix& operator -= (const MintMatrix &r) { assert(height() == r.height()); assert(width() == r.width()); for (int i = 0; i < height(); ++i) { for (int j = 0; j < width(); ++j) { val[i][j] -= r.val[i][j]; } } return *this; } constexpr MintMatrix& operator *= (const mint &v) { for (int i = 0; i < height(); ++i) for (int j = 0; j < width(); ++j) val[i][j] *= v; return *this; } constexpr MintMatrix& operator *= (const MintMatrix &r) { assert(width() == r.height()); MintMatrix<mint> res(height(), r.width()); for (int i = 0; i < height(); ++i) for (int j = 0; j < r.width(); ++j) for (int k = 0; k < width(); ++k) res[i][j] += val[i][k] * r.val[k][j]; return (*this) = res; } constexpr MintMatrix operator + () const { return MintMatrix(*this); } constexpr MintMatrix operator - () const { return MintMatrix(*this) *= mint(-1); } constexpr MintMatrix operator + (const MintMatrix &r) const { return MintMatrix(*this) += r; } constexpr MintMatrix operator - (const MintMatrix &r) const { return MintMatrix(*this) -= r; } constexpr MintMatrix operator * (const mint &v) const { return MintMatrix(*this) *= v; } constexpr MintMatrix operator * (const MintMatrix &r) const { return MintMatrix(*this) *= r; } // pow constexpr MintMatrix pow(long long n) const { assert(height() == width()); MintMatrix<mint> res(height(), width()), mul(*this); for (int row = 0; row < height(); ++row) res[row][row] = 1; while (n > 0) { if (n & 1) res *= mul; mul *= mul; n >>= 1; } return res; } friend constexpr MintMatrix<mint> pow(const MintMatrix<mint> &mat, long long n) { return mat.pow(n); } // gauss-jordan constexpr int find_pivot(int cur_rank, int col) const { int pivot = -1; for (int row = cur_rank; row < height(); ++row) { if (val[row][col] != 0) { pivot = row; break; } } return pivot; } constexpr void sweep(int cur_rank, int col, int pivot) { swap(val[pivot], val[cur_rank]); auto ifac = val[cur_rank][col].inv(); for (int col2 = 0; col2 < width(); ++col2) { val[cur_rank][col2] *= ifac; } for (int row = 0; row < height(); ++row) { if (row != cur_rank && val[row][col] != 0) { auto fac = val[row][col]; for (int col2 = 0; col2 < width(); ++col2) { val[row][col2] -= val[cur_rank][col2] * fac; } } } } constexpr int gauss_jordan(int not_sweep_width = 0) { int rank = 0; for (int col = 0; col < width(); ++col) { if (col == width() - not_sweep_width) break; int pivot = find_pivot(rank, col); if (pivot == -1) continue; sweep(rank++, col, pivot); } return rank; } friend constexpr int gauss_jordan(MintMatrix<mint> &mat, int not_sweep_width = 0) { return mat.gauss_jordan(not_sweep_width); } friend constexpr int linear_equation (const MintMatrix<mint> &mat, const vector<mint> &b, vector<mint> &res) { // extend MintMatrix<mint> A(mat.height(), mat.width() + 1); for (int i = 0; i < mat.height(); ++i) { for (int j = 0; j < mat.width(); ++j) A[i][j] = mat.val[i][j]; A[i].back() = b[i]; } int rank = A.gauss_jordan(1); // check if it has no solution for (int row = rank; row < mat.height(); ++row) if (A[row].back() != 0) return -1; // answer res.assign(mat.width(), 0); for (int i = 0; i < rank; ++i) res[i] = A[i].back(); return rank; } friend constexpr int linear_equation(const MintMatrix<mint> &mat, const vector<mint> &b) { vector<mint> res; return linear_equation(mat, b, res); } // determinant constexpr mint det() const { MintMatrix<mint> A(*this); int rank = 0; mint res = 1; for (int col = 0; col < width(); ++col) { int pivot = A.find_pivot(rank, col); if (pivot == -1) return mint(0); res *= A[pivot][rank]; A.sweep(rank++, col, pivot); } return res; } friend constexpr mint det(const MintMatrix<mint> &mat) { return mat.det(); } }; // modint template<int MOD> struct Fp { // inner value long long val; // constructor constexpr Fp() : val(0) { } constexpr Fp(long long v) : val(v % MOD) { if (val < 0) val += MOD; } constexpr Fp(const Fp &v) : val(v.get()) { } constexpr long long get() const { return val; } constexpr int get_mod() const { return MOD; } // arithmetic operators constexpr Fp operator + () const { return Fp(*this); } constexpr Fp operator - () const { return Fp(0) - Fp(*this); } constexpr Fp operator + (const Fp &r) const { return Fp(*this) += r; } constexpr Fp operator - (const Fp &r) const { return Fp(*this) -= r; } constexpr Fp operator * (const Fp &r) const { return Fp(*this) *= r; } constexpr Fp operator / (const Fp &r) const { return Fp(*this) /= r; } constexpr Fp& operator += (const Fp &r) { val += r.val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -= (const Fp &r) { val -= r.val; if (val < 0) val += MOD; return *this; } constexpr Fp& operator *= (const Fp &r) { val = val * r.val % MOD; return *this; } constexpr Fp& operator /= (const Fp &r) { long long a = r.val, b = MOD, u = 1, v = 0; while (b) { long long t = a / b; a -= t * b, swap(a, b); u -= t * v, swap(u, v); } val = val * u % MOD; if (val < 0) val += MOD; return *this; } constexpr Fp pow(long long n) const { Fp res(1), mul(*this); while (n > 0) { if (n & 1) res *= mul; mul *= mul; n >>= 1; } return res; } constexpr Fp inv() const { Fp res(1), div(*this); return res / div; } // other operators constexpr bool operator == (const Fp &r) const { return this->val == r.val; } constexpr bool operator != (const Fp &r) const { return this->val != r.val; } constexpr Fp& operator ++ () { ++val; if (val >= MOD) val -= MOD; return *this; } constexpr Fp& operator -- () { if (val == 0) val += MOD; --val; return *this; } constexpr Fp operator ++ (int) const { Fp res = *this; ++*this; return res; } constexpr Fp operator -- (int) const { Fp res = *this; --*this; return res; } friend constexpr istream& operator >> (istream &is, Fp<MOD> &x) { is >> x.val; x.val %= MOD; if (x.val < 0) x.val += MOD; return is; } friend constexpr ostream& operator << (ostream &os, const Fp<MOD> &x) { return os << x.val; } friend constexpr Fp<MOD> pow(const Fp<MOD> &r, long long n) { return r.pow(n); } friend constexpr Fp<MOD> inv(const Fp<MOD> &r) { return r.inv(); } }; int main() { const int MOD = 998244353; using mint = Fp<MOD>; long long N, M; cin >> N >> M; vector<long long> P(N); for (int i = 0; i < N; ++i) cin >> P[i], --P[i]; long long NC = N * (N - 1) / 2, N2C = (N - 2) * (N - 3) / 2; mint all = mint(NC).pow(M); if (N == 2) { cout << all << endl; return 0; } auto Q = P; sort(Q.begin(), Q.end()); vector<long long> left(N+1, 0), right(N+1, 0); for (int i = 0; i < N; ++i) { left[i+1] = left[i] + Q[i]; right[i+1] = right[i] + Q[N-i-1]; } auto calc_sum = [&](long long x) -> long long { long long l = lower_bound(Q.begin(), Q.end(), x) - Q.begin(); return (x * l - left[l]) + (right[N - l] - x * (N - l)); }; vector<mint> f(N, 0); mint S = 0; for (int i = 0; i < N; ++i) { f[i] = calc_sum(P[i]); S += f[i]; } MintMatrix<mint> A(4, 4); A[0][0] = mint(N2C + 1); // / NC; A[0][1] = A[0][2] = A[1][2] = A[2][1] = mint(1); // / NC; A[1][0] = A[2][0] = mint(N - 2); // / NC; A[1][1] = A[2][2] = mint(N2C + N - 2); // / NC; A[1][3] = A[2][3] = mint(2); // / NC; A[3][1] = A[3][2] = mint(N - 3); // / NC; A[3][3] = mint(N2C + N * 2 - 7); // / NC; auto AM = pow(A, M); mint res = 0; for (int i = 0; i + 1 < N; ++i) { mint diff = abs(P[i] - P[i+1]); res += AM[0][0] * diff; res += AM[1][0] * (f[i] - diff) / mint(N - 2); res += AM[2][0] * (f[i+1] - diff) / mint(N - 2); if (N > 3) res += AM[3][0] * (mint(S) / 2 - f[i] - f[i+1] + diff) / mint(N2C); } cout << res << endl; }