
Commit bd15b33

update to match the camera-ready version

1 parent 39ca570 commit bd15b33

10 files changed: +310 -52 lines

README.md (+63 -33)

````diff
@@ -1,19 +1,5 @@
 # On Sparse Modern Hopfield Model
-This is the code of the paper [On Sparse Modern Hopfield Model](https://arxiv.org/pdf/2309.12673.pdf). You can use this repo to reproduce the results in the paper.
-
-
-## Citations
-Please consider citing our paper in your publications if it helps. Here is the bibtex:
-
-```
-@inproceedings{
-hu2023on,
-title={On Sparse Modern Hopfield Model},
-author={Jerry Yao-Chieh Hu and Donglin Yang and Dennis Wu and Chenwei Xu and Bo-Yu Chen and Han Liu},
-booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
-year={2023},
-url={https://openreview.net/forum?id=eCgWNU2Imw}
-}
-```
+This is the code of the paper [On Sparse Modern Hopfield Model](https://arxiv.org/pdf/2309.12673.pdf). You can use this repo to reproduce the results of our method.
 
 ## Environmental Setup
 
@@ -25,12 +11,61 @@ $ conda activate sparse_hopfield
 $ pip3 install -r requirements.txt
 ```
 
+## Examples
+
+In ```layers.py```, we have implemented the dense Hopfield, the sparse Hopfield, the Hopfield with entmax-1.5, and the generalized sparse Hopfield.
+To use them, see below:
+
+```python
+dense_hp = HopfieldPooling(
+    d_model=d_model,
+    n_heads=n_heads,
+    mix=True,
+    update_steps=update_steps,
+    dropout=dropout,
+    mode="softmax",
+    scale=scale,
+    num_pattern=num_pattern)  # Dense Hopfield
+
+sparse_hp = HopfieldPooling(
+    d_model=d_model,
+    n_heads=n_heads,
+    mix=True,
+    update_steps=update_steps,
+    dropout=dropout,
+    mode="sparsemax",
+    scale=scale,
+    num_pattern=num_pattern)  # Sparse Hopfield
+
+entmax_hp = HopfieldPooling(
+    d_model=d_model,
+    n_heads=n_heads,
+    mix=True,
+    update_steps=update_steps,
+    dropout=dropout,
+    mode="entmax",
+    scale=scale,
+    num_pattern=num_pattern)  # Hopfield with entmax-1.5
+
+gsh_hp = HopfieldPooling(
+    d_model=d_model,
+    n_heads=n_heads,
+    mix=True,
+    update_steps=update_steps,
+    dropout=dropout,
+    mode="gsh",
+    scale=scale,
+    num_pattern=num_pattern)  # Generalized sparse Hopfield with learnable alpha
+```
+
 ## Experimental Validation of Theoretical Results
 
 ### Plotting
 
 ```shell
-$ python3 Plotting.py
+$ cd theoretical_results_validation
+$ python3 plotting.py
 ```
 
 ## Multiple Instance Learning (MIL) Tasks
@@ -128,21 +163,16 @@ Argument options
 * `gpus_per_trial`: how many GPUs you want to use for a single run (set this carefully for hyperparameter tuning; no larger than 1)
 * `gpus_id`: specify which GPUs you want to use (e.g. `--gpus_id=0, 1` means cuda:0 and cuda:1 are used for this script)
 
+## Citations
+Please consider citing our paper in your publications if it helps. Here is the bibtex:
 
-## Acknowledgment
-
-The authors would like to thank the anonymous reviewers and program chairs for constructive comments.
-
-JH is partially supported by the Walter P. Murphy Fellowship.
-HL is partially supported by NIH R01LM1372201, NSF CAREER1841569, DOE DE-AC02-07CH11359, DOE LAB 20-2261 and a NSF TRIPODS1740735.
-This research was supported in part through the computational resources and staff contributions provided for the Quest high performance computing facility at Northwestern University, which is jointly supported by the Office of the Provost, the Office for Research, and Northwestern University Information Technology.
-The content is solely the responsibility of the authors and does not necessarily represent the official views of the funding agencies.
-
-The experiments in this work benefit from the following open-source codes:
-* Ramsauer, Hubert, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Thomas Adler, Lukas Gruber, et al. "Hopfield networks is all you need." arXiv preprint arXiv:2008.02217 (2020). https://github.com/ml-jku/hopfield-layers
-* Martins, André, and Ramón Astudillo. "From softmax to sparsemax: A sparse model of attention and multi-label classification." In International Conference on Machine Learning, pp. 1614-1623. PMLR, 2016. https://github.com/KrisKorrel/sparsemax-pytorch
-* Correia, Gonçalo M., Vlad Niculae, and André F. T. Martins. "Adaptively sparse transformers." arXiv preprint arXiv:1909.00015 (2019). https://github.com/deep-spin/entmax & https://github.com/prajjwal1/adaptive_transformer
-* Ilse, Maximilian, Jakub Tomczak, and Max Welling. "Attention-based deep multiple instance learning." In International Conference on Machine Learning, pp. 2127-2136. PMLR, 2018. https://github.com/AMLab-Amsterdam/AttentionDeepMIL
-* Zhang, Yunhao, and Junchi Yan. "Crossformer: Transformer utilizing cross-dimension dependency for multivariate time series forecasting." In The Eleventh International Conference on Learning Representations, 2023. https://github.com/Thinklab-SJTU/Crossformer
-* Millidge, Beren, Tommaso Salvatori, Yuhang Song, Thomas Lukasiewicz, and Rafal Bogacz. "Universal Hopfield networks: A general framework for single-shot associative memory models." In International Conference on Machine Learning, pp. 15561-15583. PMLR, 2022. https://github.com/BerenMillidge/Theory_Associative_Memory
+```
+@misc{hu2023sparse,
+      title={On Sparse Modern Hopfield Model},
+      author={Jerry Yao-Chieh Hu and Donglin Yang and Dennis Wu and Chenwei Xu and Bo-Yu Chen and Han Liu},
+      year={2023},
+      eprint={2309.12673},
+      archivePrefix={arXiv},
+      primaryClass={cs.LG}
+}
+```
````
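As a quick orientation for the example block added to the README above, here is a minimal, self-contained usage sketch of `HopfieldPooling` from `layers.py`. The shape convention is inferred from how `models.py` calls the layer in this same commit (input `(batch, bag_size, d_model)`, output `num_pattern` pooled vectors per sample); the concrete numbers here are hypothetical.

```python
import torch
from layers import HopfieldPooling

# Hypothetical sizes; shape convention inferred from models.py.
pool = HopfieldPooling(
    d_model=64, n_heads=4, mix=True, update_steps=1,
    dropout=0.0, mode="gsh", scale=None, num_pattern=2)

x = torch.randn(8, 16, 64)  # 8 samples, each a bag of 16 instance embeddings
out = pool(x)               # expected: (8, 2, 64), num_pattern pooled vectors each
print(out.shape)
```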

imgs.zip (1.31 MB): binary file not shown.

layers.py (+4 -2)

```diff
@@ -7,7 +7,7 @@
 from math import sqrt
 from utils.sparse_max import Sparsemax
 from utils.entmax import Entmax15
-
+from utils.general_entmax import EntmaxAlpha
 
 class FullAttention(nn.Module):
     '''
@@ -101,6 +101,8 @@ def __init__(self, scale=None, attention_dropout=0.0, mode='sparsemax', norm=Fal
             self.softmax = Sparsemax(dim=-1)
         elif mode == 'entmax':
             self.softmax = Entmax15(dim=-1)
+        elif mode == 'gsh':
+            self.softmax = EntmaxAlpha(dim=-1)
         else:
             self.softmax = nn.Softmax(dim=-1)
@@ -274,7 +276,7 @@ def __init__(
 
         self.ln = nn.LayerNorm(d_model, elementwise_affine=False)
 
-        if mode in ["sparsemax", "softmax", "entmax"]:
+        if mode in ["sparsemax", "softmax", "entmax", "gsh"]:
             self.inner_attention = HopfieldCore(
                 scale=scale, attention_dropout=dropout, mode=mode, norm=True)
```
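The new `gsh` branch plugs `EntmaxAlpha` (entmax with a learnable exponent alpha) into the same dispatch that already covers softmax, sparsemax, and entmax-1.5. A rough comparison sketch, assuming the three `utils` modules all follow the `Module(dim=-1)(scores)` calling convention visible in the diff:

```python
import torch
from utils.sparse_max import Sparsemax
from utils.entmax import Entmax15
from utils.general_entmax import EntmaxAlpha  # new in this commit

scores = torch.tensor([[2.0, 1.0, 0.1, -1.0]])

print(torch.softmax(scores, dim=-1))  # dense: every entry strictly positive
print(Sparsemax(dim=-1)(scores))      # sparse: low scores clipped to exactly 0
print(Entmax15(dim=-1)(scores))       # alpha = 1.5: between softmax and sparsemax
print(EntmaxAlpha(dim=-1)(scores))    # 'gsh': alpha is learned during training
```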

mnist_mil_main.py (+15 -13)

```diff
@@ -15,7 +15,7 @@ def get_args():
 
     # Model params
     parser.add_argument('--mode', default="softmax", choices=["softmax", "entmax", "sparsemax"])
-    parser.add_argument('--d_model', default=256, type=int)
+    parser.add_argument('--d_model', default=512, type=int)
     parser.add_argument('--input_size', default=784, type=int)
     parser.add_argument('--model', default="pooling", type=str)
     parser.add_argument('--num_pattern', default=2, type=int)
@@ -25,7 +25,7 @@ def get_args():
     parser.add_argument('--dropout', default=0.3, type=float)
 
     # Training params
-    parser.add_argument('--lr', default=1e-3, type=float)
+    parser.add_argument('--lr', default=1e-4, type=float)
     parser.add_argument('--epoch', default=100, type=int)
     parser.add_argument('--seed', default=1111, type=int)
@@ -37,7 +37,6 @@ def get_args():
     parser.add_argument('--bag_size', default=10, type=int)
     parser.add_argument('--tgt_num', default=9, type=int)
 
-
     args = parser.parse_args()
 
     return vars(args)
@@ -47,19 +46,22 @@ def get_args():
 
 torch.set_num_threads(3)
 config = get_args()
-trails = 1
+trails = 5
 torch.manual_seed(config["seed"])
 
+
+
 if config["bag_size"] == 100:
     config["num_pattern"] = 4
 bag_size = config["bag_size"]
 # bag_size = [5, 10, 20, 50, 100, 200, 300]
-models = ["softmax", "sparsemax", "entmax"]
+models = ["softmax", "sparsemax", "entmax", "gsh"]
 data_log = None
 
 for m in models:
     config["mode"] = m
     for t in range(trails):
+        torch.random.manual_seed(torch.random.seed())
         trainer = Trainer(config, t)
         trail_log = trainer.train()
         if data_log is None:
@@ -68,22 +70,22 @@ def get_args():
         for k,v in data_log.items():
             data_log[k] = data_log[k] + trail_log[k]
 
-sns.lineplot(data=data_log, x="epoch", y="train loss", hue="model")
+sns.lineplot(data=data_log, x="epoch", y="train loss", hue="model", alpha=0.4, errorbar=None, linewidth=2)
 plt.tight_layout()
-plt.savefig(f'./imgs/train_loss_{bag_size}.png')
+plt.savefig(f'./imgs/train_loss_{bag_size}.pdf')
 plt.clf()
 
-sns.lineplot(data=data_log, x="epoch", y="test loss", hue="model")
+sns.lineplot(data=data_log, x="epoch", y="test loss", hue="model", alpha=0.4, errorbar=None, linewidth=2)
 plt.tight_layout()
-plt.savefig(f'./imgs/test_loss_{bag_size}.png')
+plt.savefig(f'./imgs/test_loss_{bag_size}.pdf')
 plt.clf()
 
-sns.lineplot(data=data_log, x="epoch", y="train acc", hue="model")
+sns.lineplot(data=data_log, x="epoch", y="train acc", hue="model", alpha=0.4, errorbar=None, linewidth=2)
 plt.tight_layout()
-plt.savefig(f'./imgs/train_acc_{bag_size}.png')
+plt.savefig(f'./imgs/train_acc_{bag_size}.pdf')
 plt.clf()
 
-sns.lineplot(data=data_log, x="epoch", y="test acc", hue="model")
+sns.lineplot(data=data_log, x="epoch", y="test acc", hue="model", alpha=0.4, errorbar=None, linewidth=2)
 plt.tight_layout()
-plt.savefig(f'./imgs/test_acc_{bag_size}.png')
+plt.savefig(f'./imgs/test_acc_{bag_size}.pdf')
 plt.clf()
```
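One easy-to-miss change above: each trial now reseeds the global RNG via `torch.random.manual_seed(torch.random.seed())`. `torch.random.seed()` already reseeds PyTorch with a fresh non-deterministic value and returns it, so wrapping it in `manual_seed` is effectively redundant; its one practical benefit is that the drawn seed ends up in a variable you could log. A small sketch of an equivalent, more explicit form:

```python
import torch

# torch.random.seed() reseeds the global RNG non-deterministically
# and returns the 64-bit seed it chose.
fresh_seed = torch.random.seed()
torch.random.manual_seed(fresh_seed)   # same RNG state; seed value now in hand
print(f"trial seed: {fresh_seed}")     # log it to reproduce a single trial later
```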

mnist_mil_trainer.py (+2 -3)

```diff
@@ -43,8 +43,7 @@ def _get_data(self):
 
     def _get_model(self):
 
-        model = MNISTModel(input_size=self.config["input_size"],
-                           d_model=self.config["d_model"],
+        model = CIFARModel(d_model=self.config["d_model"],
                            n_heads=self.config["n_heads"],
                            update_steps=self.config["update_steps"],
                            dropout=self.config["dropout"],
@@ -55,7 +54,7 @@ def _get_model(self):
         return model.cuda()
 
     def _get_opt(self):
-        return torch.optim.AdamW(self.model.parameters(), lr=self.config["lr"])
+        return torch.optim.AdamW(self.model.parameters(), lr=self.config["lr"], weight_decay=0.001)
 
     def _get_cri(self):
         return torch.nn.BCEWithLogitsLoss()
```
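A note on the optimizer change: in `AdamW`, `weight_decay=0.001` is decoupled weight decay (Loshchilov and Hutter, 2019); the weights are shrunk directly at each step rather than via an L2 term folded into the gradient, as plain `Adam` with `weight_decay` would do. A minimal sketch of the new call, with a stand-in parameter list:

```python
import torch

params = [torch.nn.Parameter(torch.randn(4, 4))]  # stand-in for model.parameters()
opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.001)

loss = (params[0] ** 2).sum()
loss.backward()
opt.step()  # Adam update plus a direct 0.001-scaled shrink of the weights
```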

models.py (+49 -1)

```diff
@@ -21,7 +21,7 @@ def __init__(
         self.ln = nn.LayerNorm(d_model)
         self.ln2 = nn.LayerNorm(d_model)
 
-        if mode in ["sparsemax", 'softmax', 'entmax']:
+        if mode in ["sparsemax", 'softmax', 'entmax', 'gsh']:
             self.layer = HopfieldPooling(
                 d_model=d_model,
                 n_heads=n_heads,
@@ -42,4 +42,52 @@ def forward(self, x):
         x = self.ln(self.emb(x))
         out = self.ln2(self.gelu(self.layer(x)))
         out = out.view(bz, -1)
+        return self.fc(out).squeeze(-1)
+
+
+class CIFARModel(nn.Module):
+    def __init__(
+            self,
+            d_model=256,
+            n_heads=4,
+            update_steps=1,
+            dropout=0.1,
+            mode='softmax',
+            scale=None,
+            num_pattern=1):
+        super(CIFARModel, self).__init__()
+
+        assert d_model % n_heads == 0
+        self.emb = nn.Conv2d(3, 6, 5)
+        self.pool = nn.MaxPool2d(4, 2)
+        self.linear = nn.Linear(1014, d_model)
+
+        self.ln = nn.LayerNorm(d_model)
+        self.ln2 = nn.LayerNorm(d_model)
+
+        if mode in ["sparsemax", 'softmax', 'entmax', 'gsh']:
+            self.layer = HopfieldPooling(
+                d_model=d_model,
+                n_heads=n_heads,
+                mix=True,
+                update_steps=update_steps,
+                dropout=dropout,
+                mode=mode,
+                scale=scale,
+                num_pattern=num_pattern)
+
+        self.fc = nn.Linear(d_model*num_pattern, 1)
+        self.gelu = nn.GELU()
+
+    def forward(self, x):
+
+        bz, N, c, h, w = x.size()
+        x = x.view(bz*N, c, h, w)
+        x = self.pool(self.emb(x))
+        x = x.view(bz, N, -1)
+        x = self.ln(self.linear(x))
+        out = self.ln2(self.gelu(self.layer(x)))
+        out = out.view(bz, -1)
         return self.fc(out).squeeze(-1)
```
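The hard-coded `nn.Linear(1014, d_model)` in `CIFARModel` only lines up for 32x32 (CIFAR-sized) inputs: `Conv2d(3, 6, 5)` maps 32x32 to 28x28, `MaxPool2d(4, 2)` maps 28x28 to 13x13 (`(28 - 4) // 2 + 1 = 13`), and `6 * 13 * 13 = 1014`. A quick shape check, with hypothetical bag dimensions:

```python
import torch
from models import CIFARModel  # as added in this commit

model = CIFARModel(d_model=256, n_heads=4, mode="gsh")
bags = torch.randn(2, 10, 3, 32, 32)  # 2 bags of 10 CIFAR-sized images each
logits = model(bags)
print(logits.shape)  # expected: torch.Size([2]), one bag-level logit per bag
```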

run.sh (+7)

```diff
@@ -0,0 +1,7 @@
+python3 mnist_mil_main.py --bag_size 20
+python3 mnist_mil_main.py --bag_size 50
+python3 mnist_mil_main.py --bag_size 5
+python3 mnist_mil_main.py --bag_size 10
+python3 mnist_mil_main.py --bag_size 100
+python3 mnist_mil_main.py --bag_size 30
+python3 mnist_mil_main.py --bag_size 80
```
Two binary files changed (0 Bytes and 5.18 KB): not shown.
