@@ -37,6 +37,8 @@ def __init__(self, model, linear_gcca, outdim_size, epoch_num, batch_size, learn
         self.loss = model.loss
         self.optimizer = torch.optim.Adam(
             self.model.model_list.parameters(), lr=learning_rate, weight_decay=reg_par)
+        self.scheduler = torch.optim.lr_scheduler.StepLR(
+            self.optimizer, step_size=200, gamma=0.7)
         self.device = device

         self.linear_gcca = linear_gcca()
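For context, `StepLR` multiplies the learning rate by `gamma` after every `step_size` calls to `scheduler.step()`. A minimal standalone sketch (not from this repository; the dummy parameter and step count are illustrative only) of the decay produced by `step_size=200, gamma=0.7`:

```python
import torch

# One throwaway parameter so the optimizer has something to manage.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.7)

for step in range(1, 601):
    optimizer.step()
    scheduler.step()
    if step % 200 == 0:
        # lr = 1e-2 * 0.7 ** (step // 200): 7.0e-3, 4.9e-3, 3.43e-3
        print(step, optimizer.param_groups[0]['lr'])
```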
@@ -86,6 +88,7 @@ def fit(self, x_list, vx_list=None, tx_list=None, checkpoint='checkpoint.model')
             train_losses.append(loss.item())
             loss.backward()
             self.optimizer.step()
+            self.scheduler.step()
         train_loss = np.mean(train_losses)

         info_string = "Epoch {:d}/{:d} - time: {:.2f} - training_loss: {:.4f}"
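Note that `scheduler.step()` sits inside the batch loop, so the decay counts optimizer updates rather than epochs. A minimal runnable sketch of this ordering (toy model and data, not the repository's classes), following PyTorch's convention that `optimizer.step()` precedes `scheduler.step()`:

```python
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.7)
x, y = torch.randn(400, 8), torch.randn(400, 2)

for epoch in range(3):
    for xb, yb in zip(x.split(40), y.split(40)):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(xb), yb)
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped per batch, as in the hunk above
```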
@@ -145,7 +148,7 @@ def _get_outputs(self, x_list):
         losses = []
         outputs_list = []
         for batch_idx in batch_idxs:
-            batch_x = [x[batch_idx, :] for x in x_list]
+            batch_x = [x[batch_idx, :].to(self.device) for x in x_list]
             outputs = self.model(batch_x)
             outputs_list.append([o_j.clone().detach() for o_j in outputs])
             loss = self.loss(outputs)
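A small sketch (synthetic tensors, illustrative device choice) of why the added `.to(self.device)` matters: fancy-indexing a CPU tensor yields a CPU tensor, which would raise a device-mismatch error once the model lives on the GPU:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x_list = [torch.randn(400, 10), torch.randn(400, 20)]   # views kept on CPU
batch_idx = torch.randperm(400)[:40]                    # one shuffled batch

# Same move as the changed line: index first, then transfer the batch only.
batch_x = [x[batch_idx, :].to(device) for x in x_list]
print([b.device for b in batch_x])
```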
@@ -176,15 +179,15 @@ def save(self, name):
 outdim_size = 2

 # number of layers with nodes in each one
-layer_sizes1 = [128, 128, outdim_size]
-layer_sizes2 = [128, 128, outdim_size]
-layer_sizes3 = [128, 128, outdim_size]
+layer_sizes1 = [256, 512, 128, outdim_size]
+layer_sizes2 = [256, 512, 128, outdim_size]
+layer_sizes3 = [256, 512, 128, outdim_size]
 layer_sizes_list = [layer_sizes1, layer_sizes2, layer_sizes3]

 # the parameters for training the network
-learning_rate = 1e-3
-epoch_num = 100
-batch_size = 40
+learning_rate = 1e-2
+epoch_num = 1000
+batch_size = 400

 # the regularization parameter of the network
 # seems necessary to avoid the gradient exploding especially when non-saturating activations are used
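For illustration, one plausible way such a `layer_sizes` list could map onto a per-view MLP (a sketch only; the repository's actual builder and activation may differ), so `[256, 512, 128, outdim_size]` adds one hidden layer and widens the first two relative to `[128, 128, outdim_size]`:

```python
import torch

def build_mlp(input_size, layer_sizes):
    # Hidden layers with a nonlinearity, then a linear map to the output size.
    layers, prev = [], input_size
    for size in layer_sizes[:-1]:
        layers += [torch.nn.Linear(prev, size), torch.nn.Sigmoid()]
        prev = size
    layers.append(torch.nn.Linear(prev, layer_sizes[-1]))
    return torch.nn.Sequential(*layers)

net = build_mlp(10, [256, 512, 128, 2])
print(net(torch.randn(4, 10)).shape)  # torch.Size([4, 2])
```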
@@ -201,8 +204,10 @@ def save(self, name):
 
 N = 400
 views = create_synthData(N)
-for view in views:
-    print(view.shape)
+print('input views shape:')
+for i, view in enumerate(views):
+    print(f'view_{i}: {view.shape}')
+    views[i] = view.to(device)

 # size of the input for view 1 and view 2
 input_shape_list = [view.shape[-1] for view in views]
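Assigning the result back into `views[i]` is needed because `Tensor.to` returns a copy on the target device rather than moving the tensor in place; an equivalent list-comprehension idiom (stand-in tensors in place of `create_synthData`):

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
views = [torch.randn(400, d) for d in (10, 20, 30)]  # stand-ins for the views

# Tensor.to returns a new tensor, so keep the returned tensors.
views = [view.to(device) for view in views]
input_shape_list = [view.shape[-1] for view in views]
print(input_shape_list)  # [10, 20, 30]
```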