1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# 더 많은 정보는 42jerrykim.github.io 에서 확인하세요.
import sys
# solve_group recurses once per examined bit, so its depth stays small;
# this limit is generous headroom rather than a hard requirement.
sys.setrecursionlimit(10 ** 6)
# Rebind `input` to sys.stdin.readline (keeps the trailing newline; the
# reader strips/splits each line). Intentionally shadows the builtin.
input = sys.stdin.readline
# Default highest bit index examined (inputs presumably fit below 2**30).
MAX_BIT = 29
class BinaryTrie:
    """Binary trie over fixed-width integer keys, used for minimum-XOR queries.

    Nodes live in a flat list and node 0 is the root. Keys are walked from
    bit ``max_bit`` down to bit 0, so ``insert`` and ``query_min_xor`` should
    be called with the same ``max_bit``.
    """

    __slots__ = ("nxt",)

    def __init__(self):
        # nxt[i] = [child for bit 0, child for bit 1]; -1 means "no child".
        self.nxt = [[-1, -1]]

    def insert(self, value: int, max_bit: int) -> None:
        """Insert bits ``max_bit..0`` of *value*, creating nodes as needed."""
        cur = 0
        for b in range(max_bit, -1, -1):
            bit = (value >> b) & 1
            if self.nxt[cur][bit] == -1:
                self.nxt[cur][bit] = len(self.nxt)
                self.nxt.append([-1, -1])
            cur = self.nxt[cur][bit]

    def query_min_xor(self, value: int, max_bit: int) -> int:
        """Return the minimum of ``value ^ key`` over stored keys (bits ``max_bit..0``).

        Greedy walk: at each bit, follow the child equal to *value*'s bit
        (contributing 0 to the XOR) when it exists; otherwise take the other
        child and pay ``1 << b``.

        Raises:
            ValueError: if the walk runs off the trie, i.e. the trie is empty
                or *max_bit* is deeper than the depth the keys were inserted at.
        """
        cur = 0
        cost = 0
        for b in range(max_bit, -1, -1):
            bit = (value >> b) & 1
            child = self.nxt[cur][bit]
            if child == -1:
                child = self.nxt[cur][bit ^ 1]
                if child == -1:
                    # Without this check cur would become -1 and negative
                    # indexing would silently read an unrelated node.
                    raise ValueError(
                        "query_min_xor on an empty trie or with a max_bit "
                        "deeper than the inserted keys"
                    )
                cost |= 1 << b
            cur = child
        return cost
def solve_group(arr, bit):
    """Return the total weight of a minimum XOR spanning tree over *arr*.

    The edge weight between two values is their XOR. Only bits ``bit..0`` are
    examined and higher bits are ignored, so callers must pass a *bit* at least
    as high as the top set bit of every value.

    The values are split on bit ``bit``. Each half is solved recursively on the
    lower bits. If both halves are non-empty, they are joined by the single
    cheapest cross edge. That edge costs ``1 << bit`` plus the minimum XOR of
    the lower bits between one value from each half.

    Args:
        arr: list of ints (assumed non-negative); duplicates are joined at cost 0.
        bit: highest bit index to examine; ``bit < 0`` means no bits remain.

    Returns:
        The spanning-tree weight (0 for fewer than two values).
    """
    if bit < 0 or len(arr) <= 1:
        return 0
    mask = 1 << bit
    left = [x for x in arr if not x & mask]
    right = [x for x in arr if x & mask]
    # Empty halves cost nothing, so recurse unconditionally.
    res = solve_group(left, bit - 1) + solve_group(right, bit - 1)
    if not (left and right):
        return res
    lower_bit = bit - 1
    if lower_bit < 0:
        # No lower bits remain: the cheapest cross edge is exactly `mask`.
        return res + mask
    trie = BinaryTrie()
    for y in right:
        trie.insert(y, lower_bit)
    # min() over the real queries (rather than a fixed 1 << 30 sentinel)
    # stays correct when lower_bit >= 30; `left` is non-empty here.
    best_lower = min(trie.query_min_xor(x, lower_bit) for x in left)
    return res + mask + best_lower
def main():
    """Read N followed by N integers from stdin and print the MST weight.

    Tokens are read from the whole of stdin, so the values may span any
    number of lines.
    """
    data = sys.stdin.read().split()
    n = int(data[0])
    values = [int(tok) for tok in data[1:1 + n]]
    # Examine at least MAX_BIT bits, and more when the input holds larger
    # values. Otherwise values differing only above MAX_BIT would be
    # treated as equal (solve_group ignores bits above its start bit).
    top_bit = max(MAX_BIT, max(values, default=0).bit_length() - 1)
    print(solve_group(values, top_bit))
if __name__ == "__main__":
    main()
  |