1
1
from collections .abc import Set , Iterator , Mapping
2
- from typing import List , TypeVar , Union , Any
2
+ from typing import List , TypeVar , Union , Any , Callable , Optional , Generator
3
3
from abc import abstractmethod , ABC
4
- from functools import total_ordering
5
4
6
5
T = TypeVar ("T" )
7
6
V = TypeVar ("V" )
8
7
K = TypeVar ("K" )
9
8
10
9
11
- @total_ordering
12
- class Comparable :
13
- def __init__ (self , value ):
14
- self .value = value
10
+ class BinarySearchUtil :
11
+ @staticmethod
12
+ def binary_search (
13
+ data , item , compare_func : Optional [Callable [[Any , Any ], int ]] = None
14
+ ) -> int :
15
+ low , high = 0 , len (data ) - 1
16
+ while low <= high :
17
+ mid = (low + high ) // 2
18
+ mid_val = data [mid ] if not isinstance (data , ListBackedSet ) else data [mid ]
19
+ comparison = (
20
+ compare_func (mid_val , item )
21
+ if compare_func
22
+ else (mid_val > item ) - (mid_val < item )
23
+ )
15
24
16
- def __lt__ (self , other : Any ) -> bool :
17
- if not isinstance (other , Comparable ):
18
- return NotImplemented
19
- return self .value < other .value
25
+ if comparison < 0 :
26
+ low = mid + 1
27
+ elif comparison > 0 :
28
+ high = mid - 1
29
+ else :
30
+ return mid # item found
31
+ return - (low + 1 ) # item not found
20
32
21
- def __eq__ (self , other : Any ) -> bool :
22
- if not isinstance (other , Comparable ):
23
- return NotImplemented
24
- return self .value == other .value
33
+ @staticmethod
34
+ def default_compare (a : Any , b : Any ) -> int :
35
+ """Default comparison function for binary search, with special handling for strings."""
36
+ if isinstance (a , str ) and isinstance (b , str ):
37
+ a , b = a .replace ("/" , "\0 " ), b .replace ("/" , "\0 " )
38
+ return (a > b ) - (a < b )
25
39
26
40
27
41
class ListBackedSet (Set [T ], ABC ):
@@ -31,25 +45,14 @@ def __len__(self) -> int: ...
31
45
@abstractmethod
32
46
def __getitem__ (self , index : Union [int , slice ]) -> Union [T , List [T ]]: ...
33
47
48
+ @abstractmethod
49
+ def __iter__ (self ) -> Iterator [T ]: ...
50
+
34
51
def __contains__ (self , item : Any ) -> bool :
35
52
return self ._binary_search (item ) >= 0
36
53
37
54
def _binary_search (self , item : Any ) -> int :
38
- low = 0
39
- high = len (self ) - 1
40
- while low <= high :
41
- mid = (low + high ) // 2
42
- try :
43
- mid_val = self [mid ]
44
- if mid_val < item :
45
- low = mid + 1
46
- elif mid_val > item :
47
- high = mid - 1
48
- else :
49
- return mid # item found
50
- except TypeError :
51
- raise ValueError (f"Cannot compare items due to a type mismatch." )
52
- return - (low + 1 ) # item not found
55
+ return BinarySearchUtil .binary_search (self , item )
53
56
54
57
55
58
class ArraySet (ListBackedSet [K ]):
@@ -80,59 +83,49 @@ def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]:
80
83
return self .__data [index ]
81
84
82
85
def plusOrThis (self , element : K ) -> "ArraySet[K]" :
83
- if element in self :
86
+ index = self ._binary_search (element )
87
+ if index >= 0 :
84
88
return self
85
89
else :
90
+ insert_at = - (index + 1 )
86
91
new_data = self .__data [:]
87
- new_data .append (element )
88
- new_data .sort (key = Comparable )
92
+ new_data .insert (insert_at , element )
89
93
return ArraySet .__create (new_data )
90
94
91
95
92
96
class ArrayMap (Mapping [K , V ]):
93
- def __init__ (self , data = None ):
94
- if data is None :
95
- self .__data = []
96
- else :
97
- self .__data = data
97
+ __data : List [Union [K , V ]]
98
+
99
+ def __init__ (self ):
100
+ raise NotImplementedError ("Use ArrayMap.empty() or other class methods instead" )
101
+
102
+ @classmethod
103
+ def __create (cls , data : List [Union [K , V ]]) -> "ArrayMap[K, V]" :
104
+ instance = cls .__new__ (cls )
105
+ instance .__data = data
106
+ return instance
98
107
99
108
@classmethod
100
109
def empty (cls ) -> "ArrayMap[K, V]" :
101
110
if not hasattr (cls , "__EMPTY" ):
102
- cls .__EMPTY = cls ([])
111
+ cls .__EMPTY = cls . __create ([])
103
112
return cls .__EMPTY
104
113
105
114
def __getitem__ (self , key : K ) -> V :
106
115
index = self ._binary_search_key (key )
107
116
if index >= 0 :
108
- return self .__data [2 * index + 1 ]
117
+ return self .__data [2 * index + 1 ] # type: ignore
109
118
raise KeyError (key )
110
119
111
120
def __iter__ (self ) -> Iterator [K ]:
112
- return (self .__data [i ] for i in range (0 , len (self .__data ), 2 ))
121
+ return (self .__data [i ] for i in range (0 , len (self .__data ), 2 )) # type: ignore
113
122
114
123
def __len__ (self ) -> int :
115
124
return len (self .__data ) // 2
116
125
117
126
def _binary_search_key (self , key : K ) -> int :
118
- def compare (a , b ):
119
- """Comparator that puts '/' first in strings."""
120
- if isinstance (a , str ) and isinstance (b , str ):
121
- a , b = a .replace ("/" , "\0 " ), b .replace ("/" , "\0 " )
122
- return (a > b ) - (a < b )
123
-
124
- low , high = 0 , len (self .__data ) // 2 - 1
125
- while low <= high :
126
- mid = (low + high ) // 2
127
- mid_key = self .__data [2 * mid ]
128
- comparison = compare (mid_key , key )
129
- if comparison < 0 :
130
- low = mid + 1
131
- elif comparison > 0 :
132
- high = mid - 1
133
- else :
134
- return mid # key found
135
- return - (low + 1 ) # key not found
127
+ keys = [self .__data [i ] for i in range (0 , len (self .__data ), 2 )]
128
+ return BinarySearchUtil .binary_search (keys , key )
136
129
137
130
def plus (self , key : K , value : V ) -> "ArrayMap[K, V]" :
138
131
index = self ._binary_search_key (key )
@@ -142,12 +135,12 @@ def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
142
135
new_data = self .__data [:]
143
136
new_data .insert (insert_at * 2 , key )
144
137
new_data .insert (insert_at * 2 + 1 , value )
145
- return ArrayMap (new_data )
138
+ return ArrayMap . __create (new_data )
146
139
147
140
def minus_sorted_indices (self , indices : List [int ]) -> "ArrayMap[K, V]" :
148
141
new_data = self .__data [:]
149
142
adjusted_indices = [i * 2 for i in indices ] + [i * 2 + 1 for i in indices ]
150
- adjusted_indices .sort ()
151
- for index in reversed ( adjusted_indices ) :
143
+ adjusted_indices .sort (reverse = True )
144
+ for index in adjusted_indices :
152
145
del new_data [index ]
153
- return ArrayMap (new_data )
146
+ return ArrayMap . __create (new_data )
0 commit comments