けんちょんの競プロ精進記録

競プロの精進記録や小ネタを書いていきます

TopCoder SRM 401 DIV1 Hard - NCool (本番 5 人)

最近、ABC 201〜300 の D 問題埋めを推奨している身としては、僕も同様のトレーニングとして SRM 401〜650 辺りの DIV1 Hard 埋めを始めてみようと思い立った。

問題概要

二次元平面上において、以下の条件を満たす線分の cool 度は  n であるという。

  • 両端点が格子点である
  • 線分上の両端および内部にちょうど  n 個の格子点が含まれる

ここで、凸多角形が与えられる。凸多角形に内包される、cool 度が  n 以上である線分をすべて考えて、その線分の両端をなす格子点に色を塗っていく。色を塗られる格子点の個数を求めよ。

(この図は  n = 6 の場合)

制約

  • 多角形の頂点数は  3 以上  50 以下
  • 多角形の各頂点の座標は  0 以上  10000 以下
  •  2 \le n \le 500000

考えたこと

この問題では cool 度が  n 以上である線分をすべて考えてよいとあるが、実際には cool 度がちょうど  n である線分のみを考えれば十分。その方が数えやすそうだ。

ここで、cool 度が  n であるという条件を言い換えていくことを考える。これはもう有名問題で結論が出ていて、線分の cool 度が  n であるとは、

「線分の両端点の x 座標の差と y 座標の差の最大公約数が  n-1

となる。ということは、cool 度が  n である線分の両端点を  (a, b), (c, d) とすると

  •  a \equiv c \pmod{n-1}
  •  b \equiv d \pmod{n-1}

が成り立つ。これは必要条件であって十分条件でないが、とても扱いやすそうな条件なので、この条件が成り立つような線分の両端点に色を塗っていけば十分だと言えたら嬉しい!!!

 a \equiv c \pmod{n-1},  b \equiv d \pmod{n-1} が成り立つとき、線分の cool 度は  k(n-1) + 1 ( k は正の整数) と表せることは言える。もともと問題は「cool 度が  n 以上である線分」を考えていたくらいなので、「cool 度が  k(n-1) + 1 と表せる線分」をすべて考えることにしても全然 OK!!!

帰着した問題

以上の考察から、次の問題へと帰着された。


各格子点  (x, y) を、 (x を n で割った余り, y を n で割った余り) に基づいて、 n^{2} 通りに分類する。

与えられた凸多角形内部に、同じグループに分類される格子点が複数あるならば、それらの格子点に色を塗ることにする。

色が塗られた格子点の個数を求めよ。


ここまで来れば、あとはなんとでもなると思われた。まず、x 座標や y 座標の値としてありうる値の最大値を  M (= 10000) とする。 N \gt M + 1 の場合は 0 と返してよい。

よって、これ以降は  N \le M+1 と仮定できる。

 x 座標を  n で割った余り  r を固定して考える。 x = r, r + n, r + 2n, \dots に対して、その  x 座標をもち、凸多角形の内部にあるような格子点の  y 座標の最小値と最大値を求める方針をとった。

適宜、いもす法なども活用することで、次の配列が  O(N^{2}) の計算量で求められることがわかった。


num[i][j] ← 凸多角形に内包される、 x 座標と  y 座標を  N で割った余りがそれぞれ  i, j であるような格子点の個数


全体の計算量は  O(\min(N, M)^{2}) と評価できる。

コード

moj プラグインに基づくテストも含む。

#include <bits/stdc++.h>
using namespace std;
using pint = pair<int, int>;
using pll = pair<long long, long long>;
template<class T> inline bool chmax(T& a, T b) { if (a < b) { a = b; return 1; } return 0; }
template<class T> inline bool chmin(T& a, T b) { if (a > b) { a = b; return 1; } return 0; }

#define REP(i, n) for (long long i = 0; i < (long long)(n); ++i)
#define REP2(i, a, b) for (long long i = a; i < (long long)(b); ++i)
#define COUT(x) cout << #x << " = " << (x) << " (L" << __LINE__ << ")" << endl
template<class T1, class T2> ostream& operator << (ostream &s, pair<T1,T2> P)
{ return s << '<' << P.first << ", " << P.second << '>'; }
template<class T> ostream& operator << (ostream &s, vector<T> P)
{ for (int i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }
template<class T> ostream& operator << (ostream &s, deque<T> P)
{ for (int i = 0; i < P.size(); ++i) { if (i > 0) { s << " "; } s << P[i]; } return s; }
template<class T> ostream& operator << (ostream &s, vector<vector<T> > P)
{ for (int i = 0; i < P.size(); ++i) { s << endl << P[i]; } return s << endl; }
template<class T> ostream& operator << (ostream &s, set<T> P)
{ for(auto it : P) { s << "<" << it << "> "; } return s; }
template<class T> ostream& operator << (ostream &s, multiset<T> P)
{ for(auto it : P) { s << "<" << it << "> "; } return s; }
template<class T1, class T2> ostream& operator << (ostream &s, map<T1,T2> P)
{ for(auto it : P) { s << "<" << it.first << "->" << it.second << "> "; } return s; }


/*/////////////////////////////*/
// 幾何ライブラリ
/*/////////////////////////////*/

// basic settings
using DD = long double;
constexpr long double PI = 3.141592653589793238462643383279502884L;
constexpr long double INF = 1LL<<60;  // to be set appropriately
constexpr long double EPS = 1e-10;    // to be set appropriately
long double torad(int deg) {return (long double)(deg) * PI / 180;}
long double todeg(long double ang) {return ang * 180 / PI;}

// Point or Vector
struct Point {
    DD x, y;
    
    // constructor
    constexpr Point() : x(0), y(0) {}
    constexpr Point(DD x, DD y) : x(x), y(y) {}
    
    // various functions
    constexpr Point conj() const {return Point(x, -y);}
    constexpr DD dot(const Point &r) const {return x * r.x + y * r.y;}
    constexpr DD cross(const Point &r) const {return x * r.y - y * r.x;}
    constexpr DD norm() const {return dot(*this);}
    constexpr long double abs() const {return sqrt(norm());}
    constexpr long double amp() const {
        long double res = atan2(y, x);
        if (res < 0) res += PI*2;
        return res;
    }
    constexpr bool eq(const Point &r) const {return (*this - r).abs() <= EPS;}
    constexpr Point rot90() const {return Point(-y, x);}
    constexpr Point rot(long double ang) const {
        return Point(cos(ang) * x - sin(ang) * y, sin(ang) * x + cos(ang) * y);
    }
    
    // arithmetic operators
    constexpr Point operator - () const {return Point(-x, -y);}
    constexpr Point operator + (const Point &r) const {return Point(*this) += r;}
    constexpr Point operator - (const Point &r) const {return Point(*this) -= r;}
    constexpr Point operator * (const Point &r) const {return Point(*this) *= r;}
    constexpr Point operator / (const Point &r) const {return Point(*this) /= r;}
    constexpr Point operator * (DD r) const {return Point(*this) *= r;}
    constexpr Point operator / (DD r) const {return Point(*this) /= r;}
    constexpr Point& operator += (const Point &r) {
        x += r.x, y += r.y;
        return *this;
    }
    constexpr Point& operator -= (const Point &r) {
        x -= r.x, y -= r.y;
        return *this;
    }
    constexpr Point& operator *= (const Point &r) {
        DD tx = x, ty = y;
        x = tx * r.x - ty * r.y;
        y = tx * r.y + ty * r.x;
        return *this;
    }
    constexpr Point& operator /= (const Point &r) {
        return *this *= r.conj() / r.norm();
    }
    constexpr Point& operator *= (DD r) {
        x *= r, y *= r;
        return *this;
    }
    constexpr Point& operator /= (DD r) {
        x /= r, y /= r;
        return *this;
    }

    // friend functions
    friend ostream& operator << (ostream &s, const Point &p) {
        return s << '(' << p.x << ", " << p.y << ')';
    }
    friend constexpr Point conj(const Point &p) {return p.conj();}
    friend constexpr DD dot(const Point &p, const Point &q) {return p.dot(q);}
    friend constexpr DD cross(const Point &p, const Point &q) {return p.cross(q);}
    friend constexpr DD norm(const Point &p) {return p.norm();}
    friend constexpr long double abs(const Point &p) {return p.abs();}
    friend constexpr long double amp(const Point &p) {return p.amp();}
    friend constexpr bool eq(const Point &p, const Point &q) {return p.eq(q);}
    friend constexpr Point rot90(const Point &p) {return p.rot90();}
    friend constexpr Point rot(const Point &p, long long ang) {return p.rot(ang);}
};

// necessary for some functions
constexpr bool operator < (const Point &p, const Point &q) {
    return (abs(p.x - q.x) > EPS ? p.x < q.x : p.y < q.y);
}

// Line
struct Line : vector<Point> {
    Line(Point a = Point(0.0, 0.0), Point b = Point(0.0, 0.0)) {
        this->push_back(a);
        this->push_back(b);
    }
    friend ostream& operator << (ostream &s, const Line &l) {
        return s << '{' << l[0] << ", " << l[1] << '}';
    }
};

int ccw_for_dis(const Point &a, const Point &b, const Point &c) {
    if (cross(b-a, c-a) > EPS) return 1;
    if (cross(b-a, c-a) < -EPS) return -1;
    if (dot(b-a, c-a) < -EPS) return 2;
    if (norm(b-a) < norm(c-a) - EPS) return -2;
    return 0;
}
Point proj(const Point &p, const Line &l) {
    DD t = dot(p - l[0], l[1] - l[0]) / norm(l[1] - l[0]);
    return l[0] + (l[1] - l[0]) * t;
}
Point refl(const Point &p, const Line &l) {
    return p + (proj(p, l) - p) * 2;
}
bool is_inter_PL(const Point &p, const Line &l) {
    return (abs(p - proj(p, l)) < EPS);
}
bool is_inter_PS(const Point &p, const Line &s) {
    return (ccw_for_dis(s[0], s[1], p) == 0);
}
bool is_inter_LL(const Line &l, const Line &m) {
    return (abs(cross(l[1] - l[0], m[1] - m[0])) > EPS ||
            abs(cross(l[1] - l[0], m[0] - l[0])) < EPS);
}
bool is_inter_SS(const Line &s, const Line &t) {
    if (eq(s[0], s[1])) return is_inter_PS(s[0], t);
    if (eq(t[0], t[1])) return is_inter_PS(t[0], s);
    return (ccw_for_dis(s[0], s[1], t[0]) * ccw_for_dis(s[0], s[1], t[1]) <= 0 &&
            ccw_for_dis(t[0], t[1], s[0]) * ccw_for_dis(t[0], t[1], s[1]) <= 0);
}
DD distance_PL(const Point &p, const Line &l) {
    return abs(p - proj(p, l));
}
DD distance_PS(const Point &p, const Line &s) {
    Point h = proj(p, s);
    if (is_inter_PS(h, s)) return abs(p - h);
    return min(abs(p - s[0]), abs(p - s[1]));
}
DD distance_LL(const Line &l, const Line &m) {
    if (is_inter_LL(l, m)) return 0;
    else return distance_PL(m[0], l);
}
DD distance_SS(const Line &s, const Line &t) {
    if (is_inter_SS(s, t)) return 0;
    else return min(min(distance_PS(s[0], t), distance_PS(s[1], t)),
                    min(distance_PS(t[0], s), distance_PS(t[1], s)));
}

Point proj_for_crosspoint(const Point &p, const Line &l) {
    DD t = dot(p - l[0], l[1] - l[0]) / norm(l[1] - l[0]);
    return l[0] + (l[1] - l[0]) * t;
}
vector<Point> crosspoint(const Line &l, const Line &m) {
    vector<Point> res;
    DD d = cross(m[1] - m[0], l[1] - l[0]);
    if (abs(d) < EPS) return vector<Point>();
    res.push_back(l[0] + (l[1] - l[0]) * cross(m[1] - m[0], m[1] - l[0]) / d);
    return res;
}
vector<Point> crosspoint_SS(const Line &l, const Line &m) {
    if (is_inter_SS(l, m)) return crosspoint(l, m);
    else return vector<Point>();
}

// 凸包 (一直線上の3点を含めない)
vector<Point> convex_hull(vector<Point> &ps) {
    int n = (int)ps.size();
    vector<Point> res(2*n);
    auto cmp = [&](Point p, Point q) -> bool {
        return (abs(p.x - q.x) > EPS ? p.x < q.x : p.y < q.y);
    };
    sort(ps.begin(), ps.end(), cmp);
    int k = 0;
    for (int i = 0; i < n; ++i) {
        if (k >= 2) {
            while (cross(res[k-1] - res[k-2], ps[i] - res[k-2]) < EPS) {
                --k;
                if (k < 2) break;
            }
        }
        res[k] = ps[i]; ++k;
    }
    int t = k+1;
    for (int i = n-2; i >= 0; --i) {
        if (k >= t) {
            while (cross(res[k-1] - res[k-2], ps[i] - res[k-2]) < EPS) {
                --k;
                if (k < t) break;
            }
        }
        res[k] = ps[i]; ++k;
    }
    res.resize(k-1);
    return res;
}
/* 幾何ライブラリここまで */

const int MAX = 10000;
// 座標 x において、多角形の外周または内部に含まれる点の y 座標の最小値と最大値を求める
pair<long long, long long> calc(const vector<Point> &vp, int x) {
    DD minv = INF, maxv = -INF;
    for (int i = 0; i < vp.size(); ++i) {
        Line seg(vp[i], vp[(i+1)%vp.size()]);
        Line tate(Point(x, -1), Point(x, MAX + 1));
        const auto &inter = crosspoint_SS(seg, tate);
        if (inter.empty()) continue;
        chmin(minv, inter[0].y);
        chmax(maxv, inter[0].y);
    }
    long long mi = (long long)(minv - EPS + 1);
    long long ma = (long long)(maxv + EPS);
    return {mi, ma};
}


class NCool {
public:
    int nCoolPoints(vector <int> x, vector <int> y, int N) {
        --N;
        if (N > MAX) return 0;
        
        vector<Point> vp(x.size());
        for (int i = 0; i < x.size(); ++i) vp[i] = Point(x[i], y[i]);
        vp = convex_hull(vp);
        
        long long res = 0;
        for (int i = 0; i < N; ++i) {
            vector<long long> num(N+1, 0);
            
            //cout << "-----------------=" << endl; COUT(i);
            
            // 0 以上 r 未満の値について、num に d 足す
            auto add = [&](int r, long long d) -> void {
                long long q = r / N;
                r %= N;
                num[0] += q * d, num[N] -= q * d;
                num[0] += d, num[r] -= d;
            };
            
            for (int x = i; x <= MAX; x += N) {
                auto [mi, ma] = calc(vp, x);
                if (ma < 0) continue;
                
                //cout << x << ": " << mi << ", " << ma << endl;
                
                add(ma+1, 1), add(mi, -1);
            }
            for (int i = 0; i < N; ++i) num[i+1] += num[i];
            
            //COUT(num);
            for (int i = 0; i < N; ++i) if (num[i] > 1) res += num[i];
        }
        return res;
    }
};



// BEGIN CUT HERE
namespace moj_harness {
    int run_test_case(int);
    void run_test(int casenum = -1, bool quiet = false) {
        if (casenum != -1) {
            if (run_test_case(casenum) == -1 && !quiet) {
                cerr << "Illegal input! Test case " << casenum << " does not exist." << endl;
            }
            return;
        }
        
        int correct = 0, total = 0;
        for (int i=0;i <= 10; ++i) {
            int x = run_test_case(i);
            if (x == -1) {
                if (i >= 100) break;
                continue;
            }
            correct += x;
            ++total;
        }
        
        if (total == 0) {
            cerr << "No test cases run." << endl;
        } else if (correct < total) {
            cerr << "Some cases FAILED (passed " << correct << " of " << total << ")." << endl;
        } else {
            cerr << "All " << total << " tests passed!" << endl;
        }
    }
    
    int verify_case(int casenum, const int &expected, const int &received, clock_t elapsed) { 
        cerr << "Example " << casenum << "... "; 
        
        string verdict;
        vector<string> info;
        char buf[100];
        
        if (elapsed > CLOCKS_PER_SEC / 200) {
            sprintf(buf, "time %.2fs", elapsed * (1.0/CLOCKS_PER_SEC));
            info.push_back(buf);
        }
        
        if (expected == received) {
            verdict = "PASSED";
        } else {
            verdict = "FAILED";
        }
        
        cerr << verdict;
        if (!info.empty()) {
            cerr << " (";
            for (int i=0; i<(int)info.size(); ++i) {
                if (i > 0) cerr << ", ";
                cerr << info[i];
            }
            cerr << ")";
        }
        cerr << endl;
        
        if (verdict == "FAILED") {
            cerr << "    Expected: " << expected << endl; 
            cerr << "    Received: " << received << endl; 
        }
        
        return verdict == "PASSED";
    }

    int run_test_case(int casenum__) {
        switch (casenum__) {
            case 0: {
                int x[]                   = {0, 1, 2, 7, 7};
                int y[]                   = {3, 1, 6, 1, 5};
                int n                     = 6;
                int expected__            = 21;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
            case 1: {
                int x[]                   = {0, 1, 0};
                int y[]                   = {0, 0, 1};
                int n                     = 2;
                int expected__            = 3;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
            case 2: {
                int x[]                   = {0, 0, 1, 2, 2, 1, 0, 0, 2};
                int y[]                   = {0, 1, 2, 2, 1, 0, 0, 0, 2};
                int n                     = 3;
                int expected__            = 6;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
            case 3: {
                int x[]                   = {0, 1, 1, 2, 2, 3, 3, 4, 4, 5};
                int y[]                   = {1, 0, 2, 0, 2, 0, 2, 0, 2, 1};
                int n                     = 5;
                int expected__            = 4;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
            case 4: {
                int x[]                   = {0, 1, 1, 2, 2, 3, 3, 4, 4, 5};
                int y[]                   = {1, 0, 2, 0, 2, 0, 2, 0, 2, 1};
                int n                     = 4;
                int expected__            = 10;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
                
                // custom cases
                
            case 5: {
                int x[]                   = {0, 10000, 10000, 0};
                int y[]                   = {0, 0, 10000, 10000};
                int n                     = 10001;
                int expected__            = 40000;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
            case 6: {
                int x[]                   = {0, 10000, 0};
                int y[]                   = {0, 0, 10000};
                int n                     = 10001;
                int expected__            = 3;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
            case 7: {
                int x[]                   = {0, 10000, 10000, 0};
                int y[]                   = {0, 0, 10000, 10000};
                int n                     = 10002;
                int expected__            = 0;
                
                clock_t start__           = clock();
                int received__            = NCool().nCoolPoints(vector <int>(x, x + (sizeof x / sizeof x[0])), vector <int>(y, y + (sizeof y / sizeof y[0])), n);
                return verify_case(casenum__, expected__, received__, clock()-start__);
            }
            default:
                return -1;
        }
    }
}
 

int main(int argc, char *argv[]) {
    if (argc == 1) {
        moj_harness::run_test();
    } else {
        for (int i=1; i<argc; ++i)
            moj_harness::run_test(atoi(argv[i]));
    }
}
// END CUT HERE