Featured image of post [Algorithm] C++ 백준 20176번: Needle

[Algorithm] C++ 백준 20176번: Needle

세 장벽(위·중·아래) 구멍 3점이 일직선이면 중간 x는 위·아래 x의 평균이다. 위/아래 좌표를 0/1 배열로 만든 뒤 FFT 컨볼루션으로 합=2xm 쌍을 세어 O(R log R)에 통로 개수를 구한다.

문제: BOJ 20176 - Needle

세 개의 수평 장벽(위/중/아래)에 뚫린 구멍들의 \(x\)좌표가 주어진다. 바늘이 세 장벽을 동시에 통과하려면 각 장벽에서 하나씩 고른 3개의 구멍이 한 직선 위에 놓여야 한다. 가능한 “통로”의 개수를 세는 문제다.

문제 정보

문제 요약:

  • 위/중/아래 장벽은 서로 수직 정렬되어 있고, 위-중, 중-아래의 \(y\) 간격은 각각 1이다.
  • 각 장벽에는 1개 이상의 구멍이 있으며, 구멍은 \(x\)좌표(정수)로 주어진다. (장벽 내 \(x\)는 모두 서로 다름)
  • 위/중/아래에서 하나씩 고른 3개 구멍이 공선(collinear) 이면 통로 1개로 센다.

제한 조건:

  • 시간 제한: 1초
  • 메모리 제한: 512MB
  • \(1 \le n_u, n_m, n_l \le 50{,}000\)
  • \(x \in [-30{,}000, 30{,}000]\)

접근 방식

핵심 관찰: 공선 조건은 “중간이 평균”

좌표계를 위 장벽 \(y=1\), 중간 장벽 \(y=0\), 아래 장벽 \(y=-1\)로 두자.

세 점 \((x_u, 1), (x_m, 0), (x_l, -1)\) 이 공선이면, 두 구간의 기울기가 같아야 하므로

\[ \frac{0-1}{x_m-x_u} = \frac{-1-0}{x_l-x_m} \Rightarrow x_m-x_u = x_l-x_m \Rightarrow 2x_m = x_u + x_l \]

즉, 중간 \(x_m\)은 위/아래 \(x\)의 정확한 평균(중점) 이어야 한다.

따라서 문제는 다음으로 바뀐다:

  • 각 \(x_m\)에 대해, 위/아래에서 \(x_u + x_l = 2x_m\) 를 만족하는 쌍의 개수를 더하라.

빠른 카운팅: FFT 컨볼루션

범위가 작다(\([-30000,30000]\))는 점을 활용한다.

  • 위 장벽 구멍을 0/1 배열 \(A[i]\) 로 표시 (인덱스는 \(x\)를 SHIFT로 평행이동)
  • 아래 장벽 구멍을 0/1 배열 \(B[i]\) 로 표시
  • 컨볼루션 \(C = A * B\) 에서 \(C[s]\) 는 “SHIFT된 인덱스 합이 \(s\)”인 \((x_u, x_l)\) 쌍의 개수다.

그러면 각 중간 구멍 \(x_m\)에 대해 필요한 합은

\[ (x_u+\text{SHIFT}) + (x_l+\text{SHIFT}) = (x_u+x_l) + 2\text{SHIFT} = 2x_m + 2\text{SHIFT} \]

즉 답은 \(\sum_{x_m \in M} C[2x_m + 2\text{SHIFT}]\) 이다.

알고리즘 설계 (Mermaid Flowchart)

flowchart TD
    A["입력: nu, Ux[]
nm, Mx[]
nl, Lx[]"] --> B["SHIFT=30000으로 U/L을 0/1 배열로 변환"] B --> C["FFT로 컨볼루션 C = U * L 계산"] C --> D["각 중간 구멍 xm에 대해 idx = 2*xm + 2*SHIFT 합산"] D --> E["정답 출력"]

복잡도 분석

좌표 범위 \(R = 60001\) (SHIFT 적용 후 배열 길이)로 보면,

항목복잡도비고
시간 복잡도\(O(R \log R)\)FFT 3회(전방향 2 + 역방향 1)
공간 복잡도\(O(R)\)복소수 배열 및 결과 배열

코너 케이스 및 실수 포인트

케이스설명처리 방법
음수 좌표\(x\)가 음수 가능SHIFT(30000)로 인덱싱
반올림 오차FFT 실수부 오차로 값이 ±1 흔들릴 수 있음llround 사용 + 음수는 0으로 보정
큰 정답합이 커질 수 있음long long으로 누적
배열 길이컨볼루션 길이는 \(2R-1\)FFT size는 그 이상 2의 거듭제곱으로 패딩

구현 코드

C++

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// 42jerrykim.github.io에서 더 많은 정보를 확인 할 수 있다
#include <bits/stdc++.h>
using namespace std;

using cd = complex<double>;
const double PI = acos(-1.0);

static void fft(vector<cd>& a, bool invert) {
    int n = (int)a.size();
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }

    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * PI / len * (invert ? -1 : 1);
        cd wlen(cos(ang), sin(ang));
        for (int i = 0; i < n; i += len) {
            cd w(1);
            for (int j = 0; j < len / 2; j++) {
                cd u = a[i + j];
                cd v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }

    if (invert) {
        for (cd& x : a) x /= n;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    const int SHIFT = 30000;
    const int MAXX = 60000;             // shifted x range: [0..60000]
    const int LEN = MAXX + 1;           // 60001
    const int CONV_LEN = 2 * MAXX + 1;  // 120001

    int nu, nm, nl;
    cin >> nu;
    vector<int> U(LEN, 0), L(LEN, 0);
    for (int i = 0; i < nu; i++) {
        int x; cin >> x;
        U[x + SHIFT] = 1;
    }

    cin >> nm;
    vector<int> mids(nm);
    for (int i = 0; i < nm; i++) cin >> mids[i];

    cin >> nl;
    for (int i = 0; i < nl; i++) {
        int x; cin >> x;
        L[x + SHIFT] = 1;
    }

    int n = 1;
    while (n < CONV_LEN) n <<= 1;

    vector<cd> a(n), b(n);
    for (int i = 0; i < LEN; i++) {
        a[i] = (double)U[i];
        b[i] = (double)L[i];
    }

    fft(a, false);
    fft(b, false);
    for (int i = 0; i < n; i++) a[i] *= b[i];
    fft(a, true);

    vector<long long> conv(CONV_LEN, 0);
    for (int i = 0; i < CONV_LEN; i++) {
        long long v = llround(a[i].real());
        if (v < 0) v = 0; // numeric noise guard
        conv[i] = v;
    }

    long long ans = 0;
    for (int xm : mids) {
        int idx = 2 * xm + 2 * SHIFT;
        if (0 <= idx && idx < CONV_LEN) ans += conv[idx];
    }

    cout << ans << '\n';
    return 0;
}