Commit 549ad53

Add implementation of AdaptHD and zero-norm warning (#165)
* Add AdaptHD centroid update rule, and fix #120
* [github-action] formatting fixes
* Add test
* Simplify google drive download
* Add gdown to dev dependencies
* [github-action] formatting fixes
* Update tests
* Fix test
* Simpler intrvfl setup

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 53828b8 commit 549ad53

14 files changed: +114 −130 lines

dev-requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -9,4 +9,5 @@ pytest
 black
 tqdm
 openpyxl
-coverage
+coverage
+gdown

setup.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 https://packaging.python.org/guides/distributing-packages-using-setuptools/
 https://github.com/pypa/sampleproject
 """
+
 from setuptools import setup, find_packages
 
 # Read the version without importing any dependencies

torchhd/datasets/utils.py

Lines changed: 8 additions & 77 deletions
@@ -23,93 +23,24 @@
 #
 import zipfile
 import requests
-import re
 import tqdm
 
-# Code adapted from:
-# https://github.com/wkentaro/gdown/blob/941200a9a1f4fd7ab903fb595baa5cad34a30a45/gdown/download.py
-# https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
-
 
 def download_file(url, destination):
     response = requests.get(url, allow_redirects=True, stream=True)
     write_response_to_disk(response, destination)
 
 
 def download_file_from_google_drive(file_id, destination):
-    URL = "https://docs.google.com/uc"
-    params = dict(id=file_id, export="download")
-
-    with requests.Session() as session:
-        response = session.get(URL, params=params, stream=True)
-
-        # downloads right away
-        if "Content-Disposition" in response.headers:
-            write_response_to_disk(response, destination)
-            return
-
-        # try to find a confirmation token
-        token = get_google_drive_confirm_token(response)
-
-        if token:
-            params = dict(id=id, confirm=token)
-            response = session.get(URL, params=params, stream=True)
-
-        # download if confirmation token worked
-        if "Content-Disposition" in response.headers:
-            write_response_to_disk(response, destination)
-            return
-
-        # extract download url from confirmation page
-        url = get_url_from_gdrive_confirmation(response.text)
-        response = session.get(url, stream=True)
-
-        write_response_to_disk(response, destination)
-
-
-def get_google_drive_confirm_token(response):
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            return value
-
-    return None
-
-
-def get_url_from_gdrive_confirmation(contents):
-    url = ""
-    for line in contents.splitlines():
-        m = re.search(r'href="(\/uc\?export=download[^"]+)', line)
-        if m:
-            url = "https://docs.google.com" + m.groups()[0]
-            url = url.replace("&amp;", "&")
-            break
-        m = re.search('id="downloadForm" action="(.+?)"', line)
-        if m:
-            url = m.groups()[0]
-            url = url.replace("&amp;", "&")
-            break
-        m = re.search('id="download-form" action="(.+?)"', line)
-        if m:
-            url = m.groups()[0]
-            url = url.replace("&amp;", "&")
-            break
-        m = re.search('"downloadUrl":"([^"]+)', line)
-        if m:
-            url = m.groups()[0]
-            url = url.replace("\\u003d", "=")
-            url = url.replace("\\u0026", "&")
-            break
-        m = re.search('<p class="uc-error-subcaption">(.*)</p>', line)
-        if m:
-            error = m.groups()[0]
-            raise RuntimeError(error)
-    if not url:
-        raise RuntimeError(
-            "Cannot retrieve the public link of the file. "
-            "You may need to change the permission to "
-            "'Anyone with the link', or have had many accesses."
+    try:
+        import gdown
+    except ImportError:
+        raise ImportError(
+            "Downloading files from Google drive requires gdown to be installed, see: https://github.com/wkentaro/gdown"
         )
-    return url
+
+    url = f"https://drive.google.com/uc?id={file_id}"
+    gdown.download(url, destination)
 
 
 def get_download_progress_bar(response):

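With the custom Google Drive scraping removed, the helper now delegates entirely to gdown. A minimal usage sketch of the updated function; the file id and destination path below are hypothetical placeholders, not values from this commit:

    from torchhd.datasets.utils import download_file_from_google_drive

    # Hypothetical Google Drive file id and local destination, for illustration only.
    file_id = "EXAMPLE_FILE_ID"
    download_file_from_google_drive(file_id, "data/dataset.zip")

If gdown is not installed, the helper raises an ImportError pointing to https://github.com/wkentaro/gdown instead of failing partway through a download.
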
torchhd/models.py

Lines changed: 36 additions & 8 deletions
@@ -28,12 +28,8 @@
 from torch import Tensor
 from torch.nn.parameter import Parameter
 import torch.nn.init as init
-import torch.utils.data as data
-from tqdm import tqdm
-
 
 import torchhd.functional as functional
-import torchhd.datasets as datasets
 import torchhd.embeddings as embeddings
 
 
@@ -71,6 +67,7 @@ class Centroid(nn.Module):
         >>> output.size()
         torch.Size([128, 30])
     """
+
     __constants__ = ["in_features", "out_features"]
     in_features: int
     out_features: int
@@ -108,6 +105,30 @@ def add(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
         """Adds the input vectors scaled by the lr to the target prototype vectors."""
         self.weight.index_add_(0, target, input, alpha=lr)
 
+    @torch.no_grad()
+    def add_adapt(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
+        r"""Only updates the prototype vectors on wrongly predicted inputs.
+
+        Implements the iterative training method as described in `AdaptHD: Adaptive Efficient Training for Brain-Inspired Hyperdimensional Computing <https://ieeexplore.ieee.org/document/8918974>`_.
+
+        Subtracts the input from the mispredicted class prototype scaled by the learning rate
+        and adds the input to the target prototype scaled by the learning rate.
+        """
+        logit = self(input)
+        pred = logit.argmax(1)
+        is_wrong = target != pred
+
+        # cancel update if all predictions were correct
+        if is_wrong.sum().item() == 0:
+            return
+
+        input = input[is_wrong]
+        target = target[is_wrong]
+        pred = pred[is_wrong]
+
+        self.weight.index_add_(0, target, input, alpha=lr)
+        self.weight.index_add_(0, pred, input, alpha=-lr)
+
     @torch.no_grad()
     def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
         r"""Only updates the prototype vectors on wrongly predicted inputs.
@@ -137,23 +158,30 @@ def add_online(self, input: Tensor, target: Tensor, lr: float = 1.0) -> None:
         alpha1 = 1.0 - logit.gather(1, target.unsqueeze(1))
         alpha2 = logit.gather(1, pred.unsqueeze(1)) - 1.0
 
-        self.weight.index_add_(0, target, lr * alpha1 * input)
-        self.weight.index_add_(0, pred, lr * alpha2 * input)
+        self.weight.index_add_(0, target, alpha1 * input, alpha=lr)
+        self.weight.index_add_(0, pred, alpha2 * input, alpha=lr)
 
-    @torch.no_grad()
     def normalize(self, eps=1e-12) -> None:
         """Transforms all the class prototype vectors into unit vectors.
 
        After calling this, inferences can be made more efficiently by specifying ``dot=True`` in the forward pass.
        Training further after calling this method is not advised.
        """
        norms = self.weight.norm(dim=1, keepdim=True)
+
+        if torch.isclose(norms, torch.zeros_like(norms), equal_nan=True).any():
+            import warnings
+
+            warnings.warn(
+                "The norm of a prototype vector is nearly zero upon normalizing, this could indicate a bug."
+            )
+
         norms.clamp_(min=eps)
         self.weight.div_(norms)
 
     def extra_repr(self) -> str:
         return "in_features={}, out_features={}".format(
-            self.in_features, self.out_features is not None
+            self.in_features, self.out_features
         )
 
 

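The new add_adapt method gives Centroid an AdaptHD-style iterative update alongside the existing add and add_online rules. A minimal training sketch under assumed inputs: the random MAP hypervectors from torchhd.random stand in for an encoded dataset, and the learning rate and number of passes are illustrative only:

    import torch
    import torchhd
    from torchhd.models import Centroid

    dimensions, num_classes = 10000, 26
    inputs = torchhd.random(512, dimensions)         # stand-in for encoded samples
    targets = torch.randint(0, num_classes, (512,))  # stand-in for class labels

    model = Centroid(dimensions, num_classes)

    # First pass: bundle every sample into its class prototype.
    model.add(inputs, targets)

    # Further passes: the AdaptHD rule only touches prototypes of mispredicted
    # samples, adding the input to the true class and subtracting it from the
    # wrongly predicted one.
    for _ in range(10):
        model.add_adapt(inputs, targets, lr=0.5)

    # Normalizing the prototypes lets inference use a plain dot product.
    model.normalize()
    predictions = model(inputs, dot=True).argmax(dim=1)

If a class never receives any updates its prototype can remain at zero, which is exactly the situation the new warning in normalize is meant to surface.
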
torchhd/structures.py

Lines changed: 10 additions & 20 deletions
@@ -186,12 +186,10 @@ class Multiset:
     @overload
     def __init__(
         self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
-    ):
-        ...
+    ): ...
 
     @overload
-    def __init__(self, input: VSATensor, *, size=0):
-        ...
+    def __init__(self, input: VSATensor, *, size=0): ...
 
     def __init__(self, dim_or_input: Any, vsa: VSAOptions = "MAP", **kwargs):
         self.size = kwargs.get("size", 0)
@@ -334,12 +332,10 @@ class HashTable:
     @overload
     def __init__(
         self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
-    ):
-        ...
+    ): ...
 
     @overload
-    def __init__(self, input: VSATensor, *, size=0):
-        ...
+    def __init__(self, input: VSATensor, *, size=0): ...
 
     def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
         self.size = kwargs.get("size", 0)
@@ -501,12 +497,10 @@ class BundleSequence:
     @overload
     def __init__(
         self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
-    ):
-        ...
+    ): ...
 
     @overload
-    def __init__(self, input: VSATensor, *, size=0):
-        ...
+    def __init__(self, input: VSATensor, *, size=0): ...
 
     def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
         self.size = kwargs.get("size", 0)
@@ -693,12 +687,10 @@ class BindSequence:
     @overload
     def __init__(
         self, dimensions: int, vsa: VSAOptions = "MAP", *, device=None, dtype=None
-    ):
-        ...
+    ): ...
 
     @overload
-    def __init__(self, input: VSATensor, *, size=0):
-        ...
+    def __init__(self, input: VSATensor, *, size=0): ...
 
     def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
         self.size = kwargs.get("size", 0)
@@ -861,12 +853,10 @@ def __init__(
         directed=False,
         device=None,
         dtype=None
-    ):
-        ...
+    ): ...
 
     @overload
-    def __init__(self, input: VSATensor, *, directed=False):
-        ...
+    def __init__(self, input: VSATensor, *, directed=False): ...
 
     def __init__(self, dim_or_input: int, vsa: VSAOptions = "MAP", **kwargs):
         self.is_directed = kwargs.get("directed", False)

torchhd/tensors/bsbc.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class BSBCTensor(VSATensor):
 
     Because the vectors are sparse and have a fixed magnitude, we only represent the index of the non-zero value.
     """
+
     block_size: int
     supported_dtypes: Set[torch.dtype] = {
         torch.float32,

torchhd/tensors/fhrr.py

Lines changed: 7 additions & 0 deletions
@@ -395,5 +395,12 @@ def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor:
         else:
             magnitude = self_mag * others_mag
 
+        if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
+            import warnings
+
+            warnings.warn(
+                "The norm of a vector is nearly zero, this could indicate a bug."
+            )
+
         magnitude = torch.clamp(magnitude, min=eps)
         return self.dot_similarity(others) / magnitude

torchhd/tensors/hrr.py

Lines changed: 7 additions & 0 deletions
@@ -382,5 +382,12 @@ def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor:
         else:
             magnitude = self_mag * others_mag
 
+        if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
+            import warnings
+
+            warnings.warn(
+                "The norm of a vector is nearly zero, this could indicate a bug."
+            )
+
         magnitude = torch.clamp(magnitude, min=eps)
         return self.dot_similarity(others) / magnitude

torchhd/tensors/map.py

Lines changed: 7 additions & 0 deletions
@@ -368,5 +368,12 @@ def cosine_similarity(
         else:
             magnitude = self_mag * others_mag
 
+        if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
+            import warnings
+
+            warnings.warn(
+                "The norm of a vector is nearly zero, this could indicate a bug."
+            )
+
         magnitude = torch.clamp(magnitude, min=eps)
         return self.dot_similarity(others, dtype=dtype) / magnitude

torchhd/tensors/vtb.py

Lines changed: 7 additions & 0 deletions
@@ -411,5 +411,12 @@ def cosine_similarity(self, others: "VTBTensor", *, eps=1e-08) -> Tensor:
         else:
             magnitude = self_mag * others_mag
 
+        if torch.isclose(magnitude, torch.zeros_like(magnitude), equal_nan=True).any():
+            import warnings
+
+            warnings.warn(
+                "The norm of a vector is nearly zero, this could indicate a bug."
+            )
+
         magnitude = torch.clamp(magnitude, min=eps)
         return self.dot_similarity(others) / magnitude

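The same near-zero magnitude check is added to cosine_similarity in the FHRR, HRR, MAP, and VTB tensor classes. A small sketch of the case it guards against, assuming torchhd.random and torchhd.empty are used to build MAP hypervectors and that an "empty" MAP vector is all zeros (and therefore has zero magnitude):

    import warnings
    import torchhd

    a = torchhd.random(1, 10000)    # ordinary MAP hypervector
    zero = torchhd.empty(1, 10000)  # all-zero MAP hypervector, magnitude 0

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        sim = a.cosine_similarity(zero)  # magnitude clamped to eps, similarity ~0

    print(sim)
    print([str(w.message) for w in caught])

Previously the zero magnitude was silently clamped to eps; the warning now makes the ill-defined similarity visible without changing the returned value.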