# -*- coding: utf-8 -*-
import numpy as np
import torch
+ import hydra
from model import ABAE
from reader import get_centroids, get_w2v, read_data_tensors

- if __name__ == "__main__":
-
-     import argparse
-
-     parser = argparse.ArgumentParser()
-
-     parser.add_argument("--word-vectors-path", "-wv",
-                         dest="wv_path", type=str, metavar='<str>',
-                         help="path to word vectors file")
-
-     parser.add_argument("--batch-size", "-b", dest="batch_size", type=int, default=50,
-                         help="Batch size for training")
-
-     parser.add_argument("--aspects-number", "-as", dest="aspects_number", type=int, default=40,
-                         help="A total number of aspects")
-
-     parser.add_argument("--ortho-reg", "-orth", dest="ortho_reg", type=float, default=0.1,
-                         help="Ortho-regularization impact coefficient")
-
-     parser.add_argument("--epochs", "-e", dest="epochs", type=int, default=1,
-                         help="Epochs count")
-
-     parser.add_argument("--optimizer", "-opt", dest="optimizer", type=str, default="adam", help="Optimizer",
-                         choices=["adam", "sgd", "asgd", "adagrad"])
-
-     parser.add_argument("--negative-samples", "-ns", dest="neg_samples", type=int, default=5,
-                         help="Negative samples per positive one")
-
-     parser.add_argument("--dataset-path", "-d", dest="dataset_path", type=str, default="reviews_Electronics_5.json.txt",
-                         help="Path to a training texts file. One sentence per line, tokens separated with spaces.")
-
-     parser.add_argument("--maxlen", "-l", type=int, default=201,
-                         help="Max length of the considered sentence; the rest is clipped if longer")
-
-     args = parser.parse_args()
-
-     w2v_model = get_w2v(args.wv_path)
-     wv_dim = w2v_model.vector_size
-     y = torch.zeros(args.batch_size, 1)
-
-     model = ABAE(wv_dim=wv_dim,
-                  asp_count=args.aspects_number,
-                  init_aspects_matrix=get_centroids(w2v_model, aspects_count=args.aspects_number))
-     print(model)
+ @hydra.main("configs", "config")
+ def main(cfg):
+
+     w2v_model = get_w2v(cfg.embeddings.path)
+     print(cfg)
+     print(w2v_model)
+     # wv_dim = w2v_model.vector_size
+     # y = torch.zeros(args.batch_size, 1)
+     #
+     # model = ABAE(wv_dim=wv_dim,
+     #              asp_count=args.aspects_number,
+     #              init_aspects_matrix=get_centroids(w2v_model, aspects_count=args.aspects_number))
+     # print(model)
+     #
+     # criterion = torch.nn.MSELoss(reduction="sum")
+     #
+     # optimizer = None
+     # scheduler = None
+     #
+     # if args.optimizer == "adam":
+     #     optimizer = torch.optim.Adam(model.parameters())
+     # elif args.optimizer == "sgd":
+     #     optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
+     # elif args.optimizer == "adagrad":
+     #     optimizer = torch.optim.Adagrad(model.parameters())
+     # elif args.optimizer == "asgd":
+     #     optimizer = torch.optim.ASGD(model.parameters(), lr=0.05)
+     # else:
+     #     raise Exception("Optimizer '%s' is not supported" % args.optimizer)
+     #
+     # for t in range(args.epochs):
+     #
+     #     print("Epoch %d/%d" % (t + 1, args.epochs))
+     #
+     #     data_iterator = read_data_tensors(args.dataset_path, args.wv_path,
+     #                                       batch_size=args.batch_size, maxlen=args.maxlen)
+     #
+     #     for item_number, (x, texts) in enumerate(data_iterator):
+     #         if x.shape[0] < args.batch_size:  # pad with 0 if smaller than batch size
+     #             x = np.pad(x, ((0, args.batch_size - x.shape[0]), (0, 0), (0, 0)))
+     #
+     #         x = torch.from_numpy(x)
+     #
+     #         # extracting bad samples from the very same batch; not sure if this is OK, so todo
+     #         negative_samples = torch.stack(
+     #             tuple([x[torch.randperm(x.shape[0])[:args.neg_samples]] for _ in range(args.batch_size)]))
+     #
+     #         # prediction
+     #         y_pred = model(x, negative_samples)
+     #
+     #         # error computation
+     #         loss = criterion(y_pred, y)
+     #         optimizer.zero_grad()
+     #         loss.backward()
+     #         optimizer.step()
+     #
+     #         if item_number % 1000 == 0:
+     #
+     #             print(item_number, "batches, and LR:", optimizer.param_groups[0]['lr'])
+     #
+     #             for i, aspect in enumerate(model.get_aspect_words(w2v_model)):
+     #                 print(i + 1, " ".join([a for a in aspect]))
+     #
+     #             print("Loss:", loss.item())
+     #             print()

-     criterion = torch.nn.MSELoss(reduction="sum")

-     optimizer = None
-     scheduler = None
-
-     if args.optimizer == "adam":
-         optimizer = torch.optim.Adam(model.parameters())
-     elif args.optimizer == "sgd":
-         optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
-     elif args.optimizer == "adagrad":
-         optimizer = torch.optim.Adagrad(model.parameters())
-     elif args.optimizer == "asgd":
-         optimizer = torch.optim.ASGD(model.parameters(), lr=0.05)
-     else:
-         raise Exception("Optimizer '%s' is not supported" % args.optimizer)
-
-     for t in range(args.epochs):
-
-         print("Epoch %d/%d" % (t + 1, args.epochs))
-
-         data_iterator = read_data_tensors(args.dataset_path, args.wv_path,
-                                           batch_size=args.batch_size, maxlen=args.maxlen)
-
-         for item_number, (x, texts) in enumerate(data_iterator):
-             if x.shape[0] < args.batch_size:  # pad with 0 if smaller than batch size
-                 x = np.pad(x, ((0, args.batch_size - x.shape[0]), (0, 0), (0, 0)))
-
-             x = torch.from_numpy(x)
-
-             # extracting bad samples from the very same batch; not sure if this is OK, so todo
-             negative_samples = torch.stack(
-                 tuple([x[torch.randperm(x.shape[0])[:args.neg_samples]] for _ in range(args.batch_size)]))
-
-             # prediction
-             y_pred = model(x, negative_samples)
-
-             # error computation
-             loss = criterion(y_pred, y)
-             optimizer.zero_grad()
-             loss.backward()
-             optimizer.step()
-
-             if item_number % 1000 == 0:
-
-                 print(item_number, "batches, and LR:", optimizer.param_groups[0]['lr'])
-
-                 for i, aspect in enumerate(model.get_aspect_words(w2v_model)):
-                     print(i + 1, " ".join([a for a in aspect]))
-
-                 print("Loss:", loss.item())
-                 print()
+ if __name__ == "__main__":
+     main()
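Note on the new entry point: under Hydra 1.0.x the positional arguments of `@hydra.main("configs", "config")` map to `config_path` and `config_name`, so the script now expects a `configs/config.yaml` next to it, of which this commit only reads `cfg.embeddings.path`. The rest of the config schema is not part of the commit; the sketch below uses plain OmegaConf (Hydra's config backend) just to illustrate that one assumed key, and the word-vectors path value is a placeholder, not a file from the repository.

from omegaconf import OmegaConf

# Minimal stand-in for configs/config.yaml; only `embeddings.path` is dereferenced by main().
cfg = OmegaConf.create({"embeddings": {"path": "word_vectors.w2v"}})  # placeholder path

print(OmegaConf.to_yaml(cfg))  # embeddings:
                               #   path: word_vectors.w2v
print(cfg.embeddings.path)     # same attribute-style access as in main()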