Skip to content

Commit fc09b28

Browse files
committed
Simplify the implementation of bundle
1 parent ac08413 commit fc09b28

File tree

1 file changed

+31
-31
lines changed

1 file changed

+31
-31
lines changed

torchhd/tensors/mcr.py

+31-31
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,20 @@ def random(
216216
result = result.as_subclass(cls)
217217
result.block_size = block_size
218218
return result
219+
220+
def to_complex_unit(self):
    """Map each element to a unit-magnitude complex phasor.

    Each integer code ``v`` in ``[0, block_size)`` is interpreted as the
    angle ``2 * pi * v / block_size`` on the unit circle, so bundling can be
    done with ordinary complex addition.

    Returns:
        torch.Tensor: complex tensor of the same shape as ``self`` where
        every entry has magnitude 1 and phase ``2 * pi * self / block_size``.
    """
    angles = 2 * torch.pi * self / self.block_size
    # torch.polar(abs, angle) builds abs * exp(i * angle); magnitudes are all 1.
    return torch.polar(torch.ones_like(self, dtype=angles.dtype), angles)
219223

def bundle(self, other: "MCRTensor") -> "MCRTensor":
    r"""Bundle the hypervector with normalized complex vector addition.

    This produces a hypervector maximally similar to both.

    The bundling operation is used to aggregate information into a single
    hypervector.

    Args:
        other (MCR): other input hypervector

    Shapes:
        - Self: :math:`(*)`
        - Other: :math:`(*)`
        - Output: :math:`(*)`
    """
    assert self.block_size == other.block_size

    self_phasor = self.to_complex_unit()
    other_phasor = other.to_complex_unit()

    # Adding the phasors of each element.
    sum_of_phasors = self_phasor + other_phasor

    # To define the ultimate number that the summation will land on
    # we first find the theta of the summation, then quantize it to block_size.
    angles = torch.angle(sum_of_phasors)
    result = self.block_size * (angles / (2 * torch.pi))

    # When the two elements are inverses of each other the sum is 0 + 0j,
    # whose angle is ill-defined (torch.angle returns 0, which would wrongly
    # map the pair to code 0). Return the average of the two operands instead.
    is_zero = torch.isclose(sum_of_phasors, torch.zeros_like(sum_of_phasors))
    result = torch.where(is_zero, (self + other) / 2, result).round()

    # Wrap negative angles / averages back into the valid code range.
    return torch.remainder(result, self.block_size).type(self.dtype)
276270

277271
def multibundle(self) -> "MCRTensor":
    """Bundle multiple hypervectors.

    Bundles along dimension ``-2``: each element of the result is the
    quantized angle of the sum of the corresponding unit phasors of all
    stacked hypervectors.
    """
    self_phasor = self.to_complex_unit()
    sum_of_phasors = torch.sum(self_phasor, dim=-2)

    # To define the ultimate number that the summation will land on
    # we first find the theta of the summation, then quantize it to block_size.
    angles = torch.angle(sum_of_phasors)
    result = self.block_size * (angles / (2 * torch.pi))

    # When the phasors cancel out exactly the sum is 0 + 0j, whose angle is
    # ill-defined (torch.angle returns 0, which would wrongly map to code 0).
    # Return the average of the operands in that case.
    is_zero = torch.isclose(sum_of_phasors, torch.zeros_like(sum_of_phasors))
    result = torch.where(is_zero, torch.mean(self, dim=-2, dtype=torch.float), result).round()

    # Wrap negative angles / averages back into the valid code range.
    return torch.remainder(result, self.block_size).type(self.dtype)
289289

0 commit comments

Comments
 (0)