1
1
from collections .abc import Set , Iterator , Mapping
2
- from typing import List , TypeVar , Union
2
+ from typing import List , TypeVar , Union , Any
3
3
from abc import abstractmethod , ABC
4
4
5
5
T = TypeVar ("T" )
6
6
V = TypeVar ("V" )
7
7
K = TypeVar ("K" )
8
8
9
9
10
+ def _compare_normal (a , b ) -> int :
11
+ if a == b :
12
+ return 0
13
+ elif a < b :
14
+ return - 1
15
+ else :
16
+ return 1
17
+
18
+
19
+ def _compare_string_slash_first (a : str , b : str ) -> int :
20
+ return _compare_normal (a .replace ("/" , "\0 " ), b .replace ("/" , "\0 " ))
21
+
22
+
23
+ def _binary_search (data , item ) -> int :
24
+ compare_func = (
25
+ _compare_string_slash_first if isinstance (item , str ) else _compare_normal
26
+ )
27
+ low , high = 0 , len (data ) - 1
28
+ while low <= high :
29
+ mid = (low + high ) // 2
30
+ mid_val = data [mid ]
31
+ comparison = compare_func (mid_val , item )
32
+
33
+ if comparison < 0 :
34
+ low = mid + 1
35
+ elif comparison > 0 :
36
+ high = mid - 1
37
+ else :
38
+ return mid # item found
39
+ return - (low + 1 ) # item not found
40
+
41
+
10
42
class ListBackedSet (Set [T ], ABC ):
11
43
@abstractmethod
12
44
def __len__ (self ) -> int : ...
13
45
14
46
@abstractmethod
15
47
def __getitem__ (self , index : Union [int , slice ]) -> Union [T , List [T ]]: ...
16
48
17
- def __contains__ (self , item : object ) -> bool :
18
- for i in range (len (self )):
19
- if self [i ] == item :
20
- return True
21
- return False
49
+ @abstractmethod
50
+ def __iter__ (self ) -> Iterator [T ]: ...
51
+
52
+ def __contains__ (self , item : Any ) -> bool :
53
+ return self ._binary_search (item ) >= 0
54
+
55
+ def _binary_search (self , item : Any ) -> int :
56
+ return _binary_search (self , item )
22
57
23
58
24
59
class ArraySet (ListBackedSet [K ]):
25
60
__data : List [K ]
26
61
27
- def __init__ (self , data : List [ K ] ):
28
- raise NotImplementedError ("Use ArraySet.empty() instead" )
62
+ def __init__ (self ):
63
+ raise NotImplementedError ("Use ArraySet.empty() or other class methods instead" )
29
64
30
65
@classmethod
31
66
def __create (cls , data : List [K ]) -> "ArraySet[K]" :
32
- # Create a new instance without calling __init__
33
67
instance = super ().__new__ (cls )
34
68
instance .__data = data
35
69
return instance
@@ -40,82 +74,74 @@ def __iter__(self) -> Iterator[K]:
40
74
@classmethod
41
75
def empty (cls ) -> "ArraySet[K]" :
42
76
if not hasattr (cls , "__EMPTY" ):
43
- cls .__EMPTY = cls ([])
77
+ cls .__EMPTY = cls . __create ([])
44
78
return cls .__EMPTY
45
79
46
80
def __len__ (self ) -> int :
47
81
return len (self .__data )
48
82
49
83
def __getitem__ (self , index : Union [int , slice ]) -> Union [K , List [K ]]:
50
- if isinstance (index , int ):
51
- return self .__data [index ]
52
- elif isinstance (index , slice ):
53
- return self .__data [index ]
54
- else :
55
- raise TypeError ("Invalid argument type." )
84
+ return self .__data [index ]
56
85
57
86
def plusOrThis (self , element : K ) -> "ArraySet[K]" :
58
- # TODO: use binary search, and also special sort order for strings
59
- if element in self . __data :
87
+ index = self . _binary_search ( element )
88
+ if index >= 0 :
60
89
return self
61
90
else :
91
+ insert_at = - (index + 1 )
62
92
new_data = self .__data [:]
63
- new_data .append (element )
64
- new_data .sort () # type: ignore[reportOperatorIssue]
93
+ new_data .insert (insert_at , element )
65
94
return ArraySet .__create (new_data )
66
95
67
96
68
97
class ArrayMap (Mapping [K , V ]):
69
- def __init__ (self , data : list ):
70
- # TODO: hide this constructor as done in ArraySet
71
- self .__data = data
98
+ __data : List [Union [K , V ]]
99
+
100
+ def __init__ (self ):
101
+ raise NotImplementedError ("Use ArrayMap.empty() or other class methods instead" )
102
+
103
+ @classmethod
104
+ def __create (cls , data : List [Union [K , V ]]) -> "ArrayMap[K, V]" :
105
+ instance = cls .__new__ (cls )
106
+ instance .__data = data
107
+ return instance
72
108
73
109
@classmethod
74
110
def empty (cls ) -> "ArrayMap[K, V]" :
75
111
if not hasattr (cls , "__EMPTY" ):
76
- cls .__EMPTY = cls ([])
112
+ cls .__EMPTY = cls . __create ([])
77
113
return cls .__EMPTY
78
114
79
115
def __getitem__ (self , key : K ) -> V :
80
- index = self .__binary_search_key (key )
116
+ index = self ._binary_search_key (key )
81
117
if index >= 0 :
82
- return self .__data [2 * index + 1 ]
118
+ return self .__data [2 * index + 1 ] # type: ignore
83
119
raise KeyError (key )
84
120
85
121
def __iter__ (self ) -> Iterator [K ]:
86
- return (self .__data [i ] for i in range (0 , len (self .__data ), 2 ))
122
+ return (self .__data [i ] for i in range (0 , len (self .__data ), 2 )) # type: ignore
87
123
88
124
def __len__ (self ) -> int :
89
125
return len (self .__data ) // 2
90
126
91
- def __binary_search_key (self , key : K ) -> int :
92
- # TODO: special sort order for strings
93
- low , high = 0 , (len (self .__data ) // 2 ) - 1
94
- while low <= high :
95
- mid = (low + high ) // 2
96
- mid_key = self .__data [2 * mid ]
97
- if mid_key < key :
98
- low = mid + 1
99
- elif mid_key > key :
100
- high = mid - 1
101
- else :
102
- return mid
103
- return - (low + 1 )
127
+ def _binary_search_key (self , key : K ) -> int :
128
+ keys = [self .__data [i ] for i in range (0 , len (self .__data ), 2 )]
129
+ return _binary_search (keys , key )
104
130
105
131
def plus (self , key : K , value : V ) -> "ArrayMap[K, V]" :
106
- index = self .__binary_search_key (key )
132
+ index = self ._binary_search_key (key )
107
133
if index >= 0 :
108
134
raise ValueError ("Key already exists" )
109
135
insert_at = - (index + 1 )
110
136
new_data = self .__data [:]
111
- new_data [insert_at * 2 : insert_at * 2 ] = [key , value ]
112
- return ArrayMap (new_data )
137
+ new_data .insert (insert_at * 2 , key )
138
+ new_data .insert (insert_at * 2 + 1 , value )
139
+ return ArrayMap .__create (new_data )
113
140
114
- def minus_sorted_indices (self , indicesToRemove : List [int ]) -> "ArrayMap[K, V]" :
115
- if not indicesToRemove :
116
- return self
117
- newData = []
118
- for i in range (0 , len (self .__data ), 2 ):
119
- if i // 2 not in indicesToRemove :
120
- newData .extend (self .__data [i : i + 2 ])
121
- return ArrayMap (newData )
141
+ def minus_sorted_indices (self , indices : List [int ]) -> "ArrayMap[K, V]" :
142
+ new_data = self .__data [:]
143
+ adjusted_indices = [i * 2 for i in indices ] + [i * 2 + 1 for i in indices ]
144
+ adjusted_indices .sort (reverse = True )
145
+ for index in adjusted_indices :
146
+ del new_data [index ]
147
+ return ArrayMap .__create (new_data )
0 commit comments