Skip to content

Commit 54b6465

Browse files
authored
Add MCRTensor (#174)
* Add MCRTensor * Add MCRTensor header to doc files * Simplify the implementation of bundle
1 parent 132ee17 commit 54b6465

14 files changed

+570
-110
lines changed

docs/torchhd.rst

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ VSA Models
8787
HRRTensor
8888
FHRRTensor
8989
BSBCTensor
90+
MCRTensor
9091
VTBTensor
9192

9293

torchhd/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from torchhd.tensors.fhrr import FHRRTensor
3838
from torchhd.tensors.bsbc import BSBCTensor
3939
from torchhd.tensors.vtb import VTBTensor
40+
from torchhd.tensors.mcr import MCRTensor
4041

4142
from torchhd.functional import (
4243
ensure_vsa_tensor,
@@ -90,6 +91,7 @@
9091
"FHRRTensor",
9192
"BSBCTensor",
9293
"VTBTensor",
94+
"MCRTensor",
9395
"functional",
9496
"embeddings",
9597
"structures",

torchhd/functional.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torchhd.tensors.fhrr import FHRRTensor
3636
from torchhd.tensors.bsbc import BSBCTensor
3737
from torchhd.tensors.vtb import VTBTensor
38+
from torchhd.tensors.mcr import MCRTensor
3839
from torchhd.types import VSAOptions
3940

4041

@@ -90,6 +91,8 @@ def get_vsa_tensor_class(vsa: VSAOptions) -> Type[VSATensor]:
9091
return BSBCTensor
9192
elif vsa == "VTB":
9293
return VTBTensor
94+
elif vsa == "MCR":
95+
return MCRTensor
9396

9497
raise ValueError(f"Provided VSA model is not supported, specified: {vsa}")
9598

@@ -358,7 +361,7 @@ def level(
358361
device=span_hv.device,
359362
).as_subclass(vsa_tensor)
360363

361-
if vsa == "BSBC":
364+
if vsa == "BSBC" or vsa == "MCR":
362365
hv.block_size = span_hv.block_size
363366

364367
for i in range(num_vectors):
@@ -585,7 +588,7 @@ def circular(
585588
device=span_hv.device,
586589
).as_subclass(vsa_tensor)
587590

588-
if vsa == "BSBC":
591+
if vsa == "BSBC" or vsa == "MCR":
589592
hv.block_size = span_hv.block_size
590593

591594
mutation_history = deque()

0 commit comments

Comments
 (0)