
Commit eae7067

[enh] Minor improvements. E.g. 'python3' -> 'python', possible encoding troubles fixed, README updated (another dataset for faster experiments), requirements.txt added, etc.
1 parent 059e51d commit eae7067

File tree

7 files changed: +27 additions, -24 deletions


custom_format_converter.py
Lines changed: 2 additions & 2 deletions

@@ -27,9 +27,9 @@ def read_amazon_format(path: str, sentence=True):
     :param path: a path to a filename
     :param sentence: whether to split the reviews into sentences
     """
-    with open(path + ("" if sentence else "-full_text") + ".txt", "w+") as wf:
+    with open(path + ("" if sentence else "-full_text") + ".txt", "w+", encoding="utf-8") as wf:

-        for line in tqdm(open(path)):
+        for line in tqdm(open(path, "r", encoding="utf-8")):
             # reading the text
             text = json.loads(line.strip())["reviewText"].replace("\n", " ")
             # splitting into sentences
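Both open() calls now pin the encoding, so the converter behaves the same regardless of the platform's default locale. A minimal sketch of the same read/write pattern (the file names here are hypothetical, not from this repository):

import json
from tqdm import tqdm

# Explicit UTF-8 on both handles avoids UnicodeDecodeError/UnicodeEncodeError
# on systems whose locale default is not UTF-8 (e.g. cp1252 on Windows).
with open("reviews.json", "r", encoding="utf-8") as rf, \
        open("reviews.txt", "w+", encoding="utf-8") as wf:
    for line in tqdm(rf):
        text = json.loads(line.strip())["reviewText"].replace("\n", " ")
        wf.write(text + "\n")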

example_run.sh
Lines changed: 4 additions & 4 deletions

@@ -1,22 +1,22 @@
 #!/usr/bin/env bash

-DATA_NAME=reviews_Electronics_5
+DATA_NAME=reviews_Cell_Phones_and_Accessories_5

 if [ ! -f ./$DATA_NAME.json.txt ]; then
     echo "File not found! Downloading..."

     ### this may take a while
     wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/$DATA_NAME.json.gz
     gunzip $DATA_NAME.json.gz
-    python3 custom_format_converter.py $DATA_NAME.json
+    python custom_format_converter.py $DATA_NAME.json
     rm $DATA_NAME.json.gz $DATA_NAME.json
     mkdir word_vectors
 fi

 if [ ! -f ./word_vectors/$DATA_NAME.json.txt.w2v ]; then
     echo "Training custom word vectors..."
-    python3 word2vec.py $DATA_NAME.json.txt
+    python word2vec.py $DATA_NAME.json.txt
 fi

 echo "Training ABAE..."
-python3 main.py -as 30 -d $DATA_NAME.json.txt
+python main.py -as 30 -d $DATA_NAME.json.txt -wv word_vectors/$DATA_NAME.json.txt.w2v

main.py
Lines changed: 4 additions & 9 deletions

@@ -13,7 +13,6 @@

 parser.add_argument("--word-vectors-path", "-wv",
                     dest="wv_path", type=str, metavar='<str>',
-                    default="word_vectors/reviews_Electronics_5.json.txt.w2v",
                     help="path to word vectors file")

 parser.add_argument("--batch-size", "-b", dest="batch_size", type=int, default=50,
@@ -29,7 +28,7 @@
                     help="Epochs count")

 parser.add_argument("--optimizer", "-opt", dest="optimizer", type=str, default="adam", help="Optimizer",
-                    choices=["adam", "adagrad", "sgd"])
+                    choices=["adam", "sgd", "asgd", "adagrad"])

 parser.add_argument("--negative-samples", "-ns", dest="neg_samples", type=int, default=5,
                     help="Negative samples per positive one")
@@ -56,17 +55,14 @@
 optimizer = None
 scheduler = None

-# if args.optimizer == "cycsgd":
-#     optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
-#     scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=0.05, mode="triangular2")
-# elif args.optimizer == "adam":
-
 if args.optimizer == "adam":
     optimizer = torch.optim.Adam(model.parameters())
 elif args.optimizer == "sgd":
     optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
 elif args.optimizer == "adagrad":
     optimizer = torch.optim.Adagrad(model.parameters())
+elif args.optimizer == "asgd":
+    optimizer = torch.optim.ASGD(model.parameters(), lr=0.05)
 else:
     raise Exception("Optimizer '%s' is not supported" % args.optimizer)

@@ -95,14 +91,13 @@
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        # scheduler.step(epoch=t)

         if item_number % 1000 == 0:

             print(item_number, "batches, and LR:", optimizer.param_groups[0]['lr'])

             for i, aspect in enumerate(model.get_aspect_words(w2v_model)):
-                print(i + 1, " ".join(["%10s" % a for a in aspect]))
+                print(i + 1, " ".join([a for a in aspect]))

             print("Loss:", loss.item())
             print()
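Two things change above: the hard-coded default word-vectors path is removed, so -wv must now be passed explicitly (as example_run.sh does), and torch.optim.ASGD joins the supported optimizers. A minimal standalone sketch of the same selection logic written as a lookup table; the placeholder model and the learning rates are illustrative, only the optimizer names come from the diff:

import torch

model = torch.nn.Linear(200, 30)  # placeholder module, not the ABAE model

# Map each allowed --optimizer choice to a constructor; the SGD-style
# optimizers get an explicit learning rate, as in the patched main.py.
optimizers = {
    "adam": lambda p: torch.optim.Adam(p),
    "sgd": lambda p: torch.optim.SGD(p, lr=0.05),
    "asgd": lambda p: torch.optim.ASGD(p, lr=0.05),
    "adagrad": lambda p: torch.optim.Adagrad(p),
}

choice = "asgd"
if choice not in optimizers:
    raise Exception("Optimizer '%s' is not supported" % choice)
optimizer = optimizers[choice](model.parameters())
print(optimizer)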

model.py
Lines changed: 7 additions & 6 deletions

@@ -6,13 +6,13 @@


 class SelfAttention(torch.nn.Module):
-    def __init__(self, wv_dim, maxlen):
+    def __init__(self, wv_dim: int, maxlen: int):
         super(SelfAttention, self).__init__()
         self.wv_dim = wv_dim

         # max sentence length -- batch 2nd dim size
         self.maxlen = maxlen
-        self.M = Parameter(torch.Tensor(wv_dim, wv_dim))
+        self.M = Parameter(torch.empty(size=(wv_dim, wv_dim)))
         init.kaiming_uniform(self.M.data)

         # softmax for attending to wod vectors
@@ -44,7 +44,8 @@ class ABAE(torch.nn.Module):

     """

-    def __init__(self, wv_dim=200, asp_count=30, ortho_reg=0.1, maxlen=201, init_aspects_matrix=None):
+    def __init__(self, wv_dim: int = 200, asp_count: int = 30,
+                 ortho_reg: float = 0.1, maxlen: int = 201, init_aspects_matrix=None):
         """
         Initializing the model

@@ -63,7 +64,7 @@ def __init__(self, wv_dim=200, asp_count=30, ortho_reg=0.1, maxlen=201, init_asp
         self.attention = SelfAttention(wv_dim, maxlen)
         self.linear_transform = torch.nn.Linear(self.wv_dim, self.asp_count)
         self.softmax_aspects = torch.nn.Softmax()
-        self.aspects_embeddings = Parameter(torch.Tensor(wv_dim, asp_count))
+        self.aspects_embeddings = Parameter(torch.empty(size=(wv_dim, asp_count)))

         if init_aspects_matrix is None:
             torch.nn.init.xavier_uniform(self.aspects_embeddings)
@@ -80,8 +81,8 @@ def get_aspects_importances(self, text_embeddings):

         # multiplying text embeddings by attention scores -- and summing
         # (matmul: we sum every word embedding's coordinate with attention weights)
-        weighted_text_emb = torch.matmul(attention_weights.unsqueeze(1),  # (batch, 1, sentence)
-                                         text_embeddings  # (batch, sentence, wv_dim)
+        weighted_text_emb = torch.matmul(attention_weights.unsqueeze(1),  # (batch, 1, sentence)
+                                         text_embeddings                  # (batch, sentence, wv_dim)
                                          ).squeeze()

         # encoding with a simple feed-forward layer (wv_dim) -> (aspects_count)
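The last hunk only realigns the inline shape comments; the attention-weighted sum itself is unchanged. A quick shape check of that matmul, with toy sizes chosen purely for illustration:

import torch

batch, sentence, wv_dim = 4, 10, 200  # arbitrary illustration values
attention_weights = torch.softmax(torch.randn(batch, sentence), dim=-1)
text_embeddings = torch.randn(batch, sentence, wv_dim)

# (batch, 1, sentence) x (batch, sentence, wv_dim) -> (batch, 1, wv_dim)
weighted_text_emb = torch.matmul(attention_weights.unsqueeze(1), text_embeddings)
print(weighted_text_emb.shape)            # torch.Size([4, 1, 200])
print(weighted_text_emb.squeeze().shape)  # torch.Size([4, 200])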

reader.py
Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ def read_data_batches(path, batch_size=50, minlength=5):
     """
     batch = []

-    for line in open(path):
+    for line in open(path, encoding="utf-8"):
         line = line.strip().split()

         # lines with less than `minlength` words are omitted

requirements.txt
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+nltk>=3.5
+gensim>=3.8.3
+torch>=1.5.0
+torchvision>=0.6.0
+tqdm>=4.45.0
+scikit-learn>=0.22.2.post1
+numpy>=1.18.4

word2vec.py
Lines changed: 2 additions & 2 deletions

@@ -8,11 +8,11 @@


 class Sentences(object):
-    def __init__(self, filename):
+    def __init__(self, filename: str):
         self.filename = filename

     def __iter__(self):
-        for line in tqdm(codecs.open(self.filename, "r", "utf-8"), self.filename):
+        for line in tqdm(codecs.open(self.filename, "r", encoding="utf-8"), self.filename):
            yield line.strip().split()
