CF607E Cross Sum 题解
Visits: 318
题意
- 在平面直角坐标系中,给定 $n$ 条不同的直线,第 $i$ 条直线为 $y=a_i x+b_i$。
- 可重点集 $\mathcal I$ 为 $n$ 条直线两两的交点,可重数集 $\mathcal D$ 为 $\mathcal I$ 中所有点与 $(p,q)$ 的距离。
- 你需要求出 $\mathcal D$ 中前 $m$ 小的数之和。
- $n \le 5 \times 10^4$,$m \le 3 \times 10^7$,$|p|,|q|,|a_i|,|b_i| \le 10^3$ 且输入的整数值为实际值的 $10^3$ 倍,答案精度误差 $\le 10^{-6}$。
题解
为了方便,我们可以直接将 $(p,q)$ 作为坐标系原点,则每条直线可以转成 $y+q=a_i(x+p)+b_i$,也就是 $y=a_ix + a_ip+b_i-q$。
首先显然要找到一个半径最小的圆使其能够覆盖至少 $m$ 个交点。
考虑二分答案转化为判定问题。我们可以求出每条直线与圆的交点坐标并极角排序和离散化,问题就变成若干条线段有多少两两相交,是经典的扫描线问题,可以用树状数组求出。这部分时间复杂度 $\mathcal O(n \log n \log w)$。
找到这个半径之后,我们现在要计算答案。
与上面的过程类似,但由于此时交点数最多只有 $m$,因此我们可以用 set 代替树状数组,直接得到所有相交的线段,然后求出它们的交点与原点的距离。
但可能有多个点恰好在圆上,因此我们在统计距离的时候只统计严格在圆内的点,如果个数少于 $m$,则剩下的距离全都是半径。这部分时间复杂度 $\mathcal O(m \log n)$。
总时间复杂度 $\mathcal O(n \log n \log w + m \log n)$,需要注意一下精度问题。
代码
const int N = 5e4 + 7;
int n, m, x, y, t, v[N], c[N*2];
ld p, q, a[N], b[N];
pair<ld, int> s[N*2];
inline void add(int x) {
while (x <= t) ++c[x], x += x & -x;
}
inline int ask(int x) {
int k = 0;
while (x) k += c[x], x -= x & -x;
return k;
}
inline void work(ld r) {
t = 0;
for (int i = 1; i <= n; i++) {
ld a = ::a[i], b = ::b[i];
ld d = 4 * a * a * b * b - 4 * (a * a + 1) * (b * b - r * r);
if (d < eps) continue;
d = sqrt(d);
ld x1 = (-2 * a * b - d) / (2 * (a * a + 1));
ld x2 = (-2 * a * b + d) / (2 * (a * a + 1));
ld y1 = a * x1 + b, y2 = a * x2 + b;
s[++t] = mp(P(x1, y1).a(), i), s[++t] = mp(P(x2, y2).a(), i);
v[i] = 0;
}
sort(s + 1, s + t + 1);
for (int i = t; i; i--)
if (!v[s[i].se]) v[s[i].se] = i;
}
inline bool pd(ld r) {
work(r);
for (int i = 1; i <= t; i++) c[i] = 0;
ll ret = 0;
for (int i = 1; i <= t; i++)
if (v[s[i].se] != i) {
ret += ask(v[s[i].se]) - ask(i - 1);
add(v[s[i].se]);
}
return ret >= m;
}
inline ld calc(int i, int j) {
ld x = (b[j] - b[i]) / (a[i] - a[j]), y = a[i] * x + b[i];
return sqrt(x * x + y * y);
}
inline ld calc(ld r) {
work(r);
set<pi> st;
ld ret = 0;
int cnt = 0;
for (int i = 1; i <= t; i++)
if (v[s[i].se] != i) {
auto it = st.upper_bound(mp(i + 1, 0));
while (it != st.end() && (it -> fi) < v[s[i].se])
ret += calc(s[i].se, it -> se), ++it, ++cnt;
st.insert(mp(v[s[i].se], s[i].se));
}
return ret + (m - cnt) * r;
}
int main() {
rd(n), rd(x), rd(y), rd(m), p = x / 1e3, q = y / 1e3;
for (int i = 1; i <= n; i++)
rd(x), rd(y), a[i] = x / 1e3, b[i] = y / 1e3 + a[i] * p - q;
ld l = 0, r = 1e10;
for (int i = 1; i <= 100; i++) {
ld mid = (l + r) / 2;
if (pd(mid)) r = mid;
else l = mid;
}
printf("%.10Lf\n", max(0.0L, calc(r)));
return 0;
}