Skip to content
This repository was archived by the owner on Apr 4, 2024. It is now read-only.

Commit 1845272

Browse files
committedApr 3, 2024
Centralize binary search
1 parent fe16c26 commit 1845272

File tree

1 file changed

+55
-62
lines changed

1 file changed

+55
-62
lines changed
 
+55-62
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,41 @@
11
from collections.abc import Set, Iterator, Mapping
2-
from typing import List, TypeVar, Union, Any
2+
from typing import List, TypeVar, Union, Any, Callable, Optional, Generator
33
from abc import abstractmethod, ABC
4-
from functools import total_ordering
54

65
T = TypeVar("T")
76
V = TypeVar("V")
87
K = TypeVar("K")
98

109

11-
@total_ordering
12-
class Comparable:
13-
def __init__(self, value):
14-
self.value = value
10+
class BinarySearchUtil:
11+
@staticmethod
12+
def binary_search(
13+
data, item, compare_func: Optional[Callable[[Any, Any], int]] = None
14+
) -> int:
15+
low, high = 0, len(data) - 1
16+
while low <= high:
17+
mid = (low + high) // 2
18+
mid_val = data[mid] if not isinstance(data, ListBackedSet) else data[mid]
19+
comparison = (
20+
compare_func(mid_val, item)
21+
if compare_func
22+
else (mid_val > item) - (mid_val < item)
23+
)
1524

16-
def __lt__(self, other: Any) -> bool:
17-
if not isinstance(other, Comparable):
18-
return NotImplemented
19-
return self.value < other.value
25+
if comparison < 0:
26+
low = mid + 1
27+
elif comparison > 0:
28+
high = mid - 1
29+
else:
30+
return mid # item found
31+
return -(low + 1) # item not found
2032

21-
def __eq__(self, other: Any) -> bool:
22-
if not isinstance(other, Comparable):
23-
return NotImplemented
24-
return self.value == other.value
33+
@staticmethod
34+
def default_compare(a: Any, b: Any) -> int:
35+
"""Default comparison function for binary search, with special handling for strings."""
36+
if isinstance(a, str) and isinstance(b, str):
37+
a, b = a.replace("/", "\0"), b.replace("/", "\0")
38+
return (a > b) - (a < b)
2539

2640

2741
class ListBackedSet(Set[T], ABC):
@@ -31,25 +45,14 @@ def __len__(self) -> int: ...
3145
@abstractmethod
3246
def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ...
3347

48+
@abstractmethod
49+
def __iter__(self) -> Iterator[T]: ...
50+
3451
def __contains__(self, item: Any) -> bool:
3552
return self._binary_search(item) >= 0
3653

3754
def _binary_search(self, item: Any) -> int:
38-
low = 0
39-
high = len(self) - 1
40-
while low <= high:
41-
mid = (low + high) // 2
42-
try:
43-
mid_val = self[mid]
44-
if mid_val < item:
45-
low = mid + 1
46-
elif mid_val > item:
47-
high = mid - 1
48-
else:
49-
return mid # item found
50-
except TypeError:
51-
raise ValueError(f"Cannot compare items due to a type mismatch.")
52-
return -(low + 1) # item not found
55+
return BinarySearchUtil.binary_search(self, item)
5356

5457

5558
class ArraySet(ListBackedSet[K]):
@@ -80,59 +83,49 @@ def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]:
8083
return self.__data[index]
8184

8285
def plusOrThis(self, element: K) -> "ArraySet[K]":
83-
if element in self:
86+
index = self._binary_search(element)
87+
if index >= 0:
8488
return self
8589
else:
90+
insert_at = -(index + 1)
8691
new_data = self.__data[:]
87-
new_data.append(element)
88-
new_data.sort(key=Comparable)
92+
new_data.insert(insert_at, element)
8993
return ArraySet.__create(new_data)
9094

9195

9296
class ArrayMap(Mapping[K, V]):
93-
def __init__(self, data=None):
94-
if data is None:
95-
self.__data = []
96-
else:
97-
self.__data = data
97+
__data: List[Union[K, V]]
98+
99+
def __init__(self):
100+
raise NotImplementedError("Use ArrayMap.empty() or other class methods instead")
101+
102+
@classmethod
103+
def __create(cls, data: List[Union[K, V]]) -> "ArrayMap[K, V]":
104+
instance = cls.__new__(cls)
105+
instance.__data = data
106+
return instance
98107

99108
@classmethod
100109
def empty(cls) -> "ArrayMap[K, V]":
101110
if not hasattr(cls, "__EMPTY"):
102-
cls.__EMPTY = cls([])
111+
cls.__EMPTY = cls.__create([])
103112
return cls.__EMPTY
104113

105114
def __getitem__(self, key: K) -> V:
106115
index = self._binary_search_key(key)
107116
if index >= 0:
108-
return self.__data[2 * index + 1]
117+
return self.__data[2 * index + 1] # type: ignore
109118
raise KeyError(key)
110119

111120
def __iter__(self) -> Iterator[K]:
112-
return (self.__data[i] for i in range(0, len(self.__data), 2))
121+
return (self.__data[i] for i in range(0, len(self.__data), 2)) # type: ignore
113122

114123
def __len__(self) -> int:
115124
return len(self.__data) // 2
116125

117126
def _binary_search_key(self, key: K) -> int:
118-
def compare(a, b):
119-
"""Comparator that puts '/' first in strings."""
120-
if isinstance(a, str) and isinstance(b, str):
121-
a, b = a.replace("/", "\0"), b.replace("/", "\0")
122-
return (a > b) - (a < b)
123-
124-
low, high = 0, len(self.__data) // 2 - 1
125-
while low <= high:
126-
mid = (low + high) // 2
127-
mid_key = self.__data[2 * mid]
128-
comparison = compare(mid_key, key)
129-
if comparison < 0:
130-
low = mid + 1
131-
elif comparison > 0:
132-
high = mid - 1
133-
else:
134-
return mid # key found
135-
return -(low + 1) # key not found
127+
keys = [self.__data[i] for i in range(0, len(self.__data), 2)]
128+
return BinarySearchUtil.binary_search(keys, key)
136129

137130
def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
138131
index = self._binary_search_key(key)
@@ -142,12 +135,12 @@ def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
142135
new_data = self.__data[:]
143136
new_data.insert(insert_at * 2, key)
144137
new_data.insert(insert_at * 2 + 1, value)
145-
return ArrayMap(new_data)
138+
return ArrayMap.__create(new_data)
146139

147140
def minus_sorted_indices(self, indices: List[int]) -> "ArrayMap[K, V]":
148141
new_data = self.__data[:]
149142
adjusted_indices = [i * 2 for i in indices] + [i * 2 + 1 for i in indices]
150-
adjusted_indices.sort()
151-
for index in reversed(adjusted_indices):
143+
adjusted_indices.sort(reverse=True)
144+
for index in adjusted_indices:
152145
del new_data[index]
153-
return ArrayMap(new_data)
146+
return ArrayMap.__create(new_data)

0 commit comments

Comments
 (0)