Skip to content

Commit 009a50c

Browse files
committed
Fix numpy arrays warning
1 parent 0d9878c commit 009a50c

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

examples/learning_with_hrr.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torchhd import embeddings, HRRTensor
1717
import torchhd.tensors
1818
from scipy.sparse import vstack, lil_matrix
19+
import numpy as np
1920

2021

2122
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -36,7 +37,7 @@ def sparse_batch_collate(batch:list):
3637
data_batch, targets_batch = zip(*batch)
3738

3839
data_batch = vstack(data_batch).tocoo()
39-
data_batch = torch.sparse_coo_tensor(data_batch.nonzero(), data_batch.data, data_batch.shape)
40+
data_batch = torch.sparse_coo_tensor(np.array(data_batch.nonzero()), data_batch.data, data_batch.shape)
4041

4142
targets_batch = torch.stack(targets_batch)
4243

@@ -67,7 +68,7 @@ def __getitem__(self, idx):
6768
if DATASET_NAME == "Wiki10-31K": # Because of this issue https://github.com/mwydmuch/napkinXC/issues/18
6869
X_train = lil_matrix(X_train[:,:-1])
6970

70-
N_freatures = X_train.shape[1]
71+
N_features = X_train.shape[1]
7172
N_classes = max(max(classes) for classes in Y_train if classes != []) + 1
7273

7374
train_dataset = multilabel_dataset(X_train,Y_train,N_classes)
@@ -77,7 +78,7 @@ def __getitem__(self, idx):
7778

7879

7980
print("Traning on \033[1m {} \033[0m. It has {} features, and {} classes."
80-
.format(DATASET_NAME,N_freatures,N_classes))
81+
.format(DATASET_NAME,N_features,N_classes))
8182

8283

8384
# Fully Connected model for the baseline comparision
@@ -168,10 +169,10 @@ def loss(self,out,target):
168169

169170

170171

171-
hrr_model = FCHRR(N_freatures,N_classes,DIMENSIONS)
172+
hrr_model = FCHRR(N_features,N_classes,DIMENSIONS)
172173
hrr_model = hrr_model.to(device)
173174

174-
baseline_model = FC(N_freatures,N_classes)
175+
baseline_model = FC(N_features,N_classes)
175176
baseline_model = baseline_model.to(device)
176177

177178

0 commit comments

Comments
 (0)