
Commit 80a6e26

[enh] Hydra configs fixes, README, LICENSE and example_run.sh script updated.

1 parent c35f05b

File tree

8 files changed: +19 −17 lines

LICENSE

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2019 Anton Alekseev
+Copyright (c) 2019-2021 Anton Alekseev
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
```

README.md

Lines changed: 4 additions & 5 deletions

````diff
@@ -2,7 +2,7 @@
 
 Yet another PyTorch implementation of the model described in the paper [**An Unsupervised Neural Attention Model for Aspect Extraction**](https://aclweb.org/anthology/papers/P/P17/P17-1036/) by He, Ruidan and Lee, Wee Sun and Ng, Hwee Tou and Dahlmeier, Daniel, **ACL2017**.
 
-**NOTA BENE**: now `gensim>=4.0.0` and `hydra` are required.
+**NOTA BENE**: as of August 2021, `gensim>=4.1.0` and `hydra-core>=1.1.0` are required.
 
 ## Example
 
@@ -23,13 +23,12 @@ python3 word2vec.py reviews_Cell_Phones_and_Accessories_5.json.txt
 ```
 And run
 
-**TODO**: running with hydra params example is in progress
 ```
-usage: main.py ...
-
+python main.py model.aspects_number=35 data.path=$DATA_NAME.json.txt model.log_progress_steps=1000
 ```
 
-For a working example of a whole pipeline please refer to `example_run.sh`
+Please see all passable params in the `configs/` directory. For a working example of a whole pipeline
+please refer to `example_run.sh`
 
 I acknowledge the implementation is raw, code modification requests and issues are welcome.
````

configs/config.yaml

Lines changed: 3 additions & 1 deletion

```diff
@@ -1,6 +1,7 @@
 defaults:
   - embeddings: word2vec-custom
   - optimizer: adam
+  - _self_
 
 data:
   path: "reviews_Cell_Phones_and_Accessories_5.json.txt"
@@ -12,7 +13,8 @@ model:
   epochs: 1
   negative_samples: 5
   max_len: 201
+  log_progress_steps: 1000
 
 hydra:
   run:
-    dir: . #results/sessions_${now:%Y-%m-%d}_${now:%H-%M-%S}
+    dir: results/sessions_${now:%Y-%m-%d}_${now:%H-%M-%S}
```
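Some context on the `_self_` entry added above: in Hydra 1.1, the position of `_self_` in the `defaults` list determines whether the primary config is merged before or after the configs it composes; placing it last means values written directly in `config.yaml` override the composed defaults. A minimal sketch (the group and option names here are illustrative, not from this repo):

```yaml
defaults:
  - optimizer: adam   # composes configs/optimizer/adam.yaml
  - _self_            # merge this file last, so the value below wins

optimizer:
  lr: 0.001           # overrides whatever adam.yaml sets for lr
```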

custom_format_converter.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -13,7 +13,7 @@
 stops = set(stopwords.words("english"))
 
 
-@lru_cache(1000000000)
+@lru_cache(maxsize=1000000000)
 def lemmatize(w: str):
     # caching the word-based lemmatizer to speed the process up
     return lemmatizer.lemmatize(w)
@@ -29,7 +29,7 @@ def read_amazon_format(path: str, sentence=True):
     """
     with open(path + ("" if sentence else "-full_text") + ".txt", "w+", encoding="utf-8") as wf:
 
-        for line in tqdm(open(path, "r", encoding="utf-8")):
+        for line in tqdm(open(path, "r", encoding="utf-8"), "normalizing texts read from [%s]" % path):
             # reading the text
             text = json.loads(line.strip())["reviewText"].replace("\n", " ")
             # splitting into sentences
```
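The `@lru_cache(maxsize=...)` change spells out `functools.lru_cache`'s first parameter explicitly; the decorator memoizes one result per distinct argument, which is why wrapping a word-level lemmatizer pays off on review text full of repeated tokens. A minimal sketch with a hypothetical stand-in for the NLTK lemmatizer:

```python
from functools import lru_cache

calls = {"n": 0}  # counts how often the function body actually runs

@lru_cache(maxsize=1_000_000)  # explicit keyword form, as in the commit
def lemmatize_stub(word: str) -> str:
    # hypothetical stand-in for nltk's WordNetLemmatizer.lemmatize
    calls["n"] += 1
    return word.rstrip("s")

print(lemmatize_stub("phones"))  # computed -> phone
print(lemmatize_stub("phones"))  # cache hit, body is not re-run
print(calls["n"])                # -> 1
```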

example_run.sh

Lines changed: 3 additions & 4 deletions

```diff
@@ -8,16 +8,15 @@ if [ ! -f ./$DATA_NAME.json.txt ]; then
     ### this may take a while
     wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/$DATA_NAME.json.gz
     gunzip $DATA_NAME.json.gz
-    python custom_format_converter.py $DATA_NAME.json
+    python3 custom_format_converter.py $DATA_NAME.json
     rm $DATA_NAME.json.gz $DATA_NAME.json
     mkdir word_vectors
 fi
 
 if [ ! -f ./word_vectors/$DATA_NAME.json.txt.w2v ]; then
     echo "Training custom word vectors..."
-    python word2vec.py $DATA_NAME.json.txt
+    python3 word2vec.py $DATA_NAME.json.txt
 fi
 
 echo "Training ABAE..."
-echo "A working example is in progress... Please see 'main.py' code."
-#python main.py -as 30 -d $DATA_NAME.json.txt -wv word_vectors/$DATA_NAME.json.txt.w2v
+python3 main.py model.aspects_number=35 data.path=$DATA_NAME.json.txt model.log_progress_steps=1000
```
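The script keys each expensive stage on the existence of its output file, so re-running it resumes rather than re-downloading and re-training. The guard pattern, sketched with an illustrative sentinel file rather than the repo's real data files:

```shell
# Idempotence guard as used in example_run.sh: run a step only when its
# output file is missing. Names here are illustrative, not from the repo.
run_once() {
    # $1: output/sentinel file, $2: message standing in for the real work
    if [ ! -f "$1" ]; then
        echo "$2"
        touch "$1"
    fi
}

sentinel=$(mktemp -u)                        # a path that does not exist yet
run_once "$sentinel" "downloading corpus"    # runs: prints the message
run_once "$sentinel" "downloading corpus"    # skipped: file now exists
rm -f "$sentinel"
```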

main.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -4,6 +4,7 @@
 import hydra
 import numpy as np
 import torch
+import os
 
 from model import ABAE
 from reader import get_centroids, get_w2v, read_data_tensors
@@ -13,7 +14,7 @@
 
 @hydra.main("configs", "config")
 def main(cfg):
-    w2v_model = get_w2v(cfg.embeddings.path)
+    w2v_model = get_w2v(os.path.join(hydra.utils.get_original_cwd(), cfg.embeddings.path))
     wv_dim = w2v_model.vector_size
     y = torch.zeros((cfg.model.batch_size, 1))
 
@@ -39,7 +40,8 @@ def main(cfg):
 
         logger.debug("Epoch %d/%d" % (t + 1, cfg.model.epochs))
 
-        data_iterator = read_data_tensors(cfg.data.path, cfg.embeddings.path,
+        data_iterator = read_data_tensors(os.path.join(hydra.utils.get_original_cwd(), cfg.data.path),
+                                          os.path.join(hydra.utils.get_original_cwd(), cfg.embeddings.path),
                                           batch_size=cfg.model.batch_size, maxlen=cfg.model.max_len)
 
         for item_number, (x, texts) in enumerate(data_iterator):
@@ -62,7 +64,7 @@ def main(cfg):
             loss.backward()
             optimizer.step()
 
-            if item_number % 1000 == 0:
+            if item_number % cfg.model.log_progress_steps == 0:
 
                 logger.info("%d batches, and LR: %.5f" % (item_number, optimizer.param_groups[0]['lr']))
```
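The `get_original_cwd()` joins address Hydra's behavior (since 1.0) of changing the working directory to `hydra.run.dir` before calling `main`, which silently breaks relative paths taken from the config. `os.path.join` also leaves absolute paths untouched, so the join is safe either way. A stdlib-only sketch, with the `original_cwd` argument standing in for `hydra.utils.get_original_cwd()` and example paths that are purely illustrative:

```python
import os

def resolve(original_cwd: str, path: str) -> str:
    """Resolve a config path against the launch directory, as main.py does.

    os.path.join discards `original_cwd` when `path` is already absolute,
    so absolute paths in the config keep working unchanged.
    """
    return os.path.join(original_cwd, path)

print(resolve("/home/user/abae", "reviews.json.txt"))
# -> /home/user/abae/reviews.json.txt
print(resolve("/home/user/abae", "/data/reviews.json.txt"))
# -> /data/reviews.json.txt
```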

word2vec.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -19,7 +19,7 @@ def __iter__(self):
 def main(path):
     sentences = Sentences(path)
     model = gensim.models.Word2Vec(sentences, vector_size=200, window=5, min_count=5, workers=7, sg=1,
-                                   negative=5, iter=1, max_vocab_size=20000)
+                                   negative=5, max_vocab_size=20000)
     model.save("word_vectors/" + path + ".w2v")
     # model.wv.save_word2vec_format("word_vectors/" + domain + ".txt", binary=False)
```
word_vectors/DO_NOT_README

Whitespace-only changes.
