
Commit d4876c8

testing NaN / best practice model
1 parent 89d6d28 commit d4876c8

2 files changed: +16 -11 lines changed


main.py

Lines changed: 14 additions & 9 deletions
@@ -37,6 +37,8 @@ def __init__(self, model, linear_gcca, outdim_size, epoch_num, batch_size, learn
         self.loss = model.loss
         self.optimizer = torch.optim.Adam(
             self.model.model_list.parameters(), lr=learning_rate, weight_decay=reg_par)
+        self.scheduler = torch.optim.lr_scheduler.StepLR(
+            self.optimizer, step_size=200, gamma=0.7)
         self.device = device
 
         self.linear_gcca = linear_gcca()
@@ -86,6 +88,7 @@ def fit(self, x_list, vx_list=None, tx_list=None, checkpoint='checkpoint.model')
                 train_losses.append(loss.item())
                 loss.backward()
                 self.optimizer.step()
+                self.scheduler.step()
             train_loss = np.mean(train_losses)
 
             info_string = "Epoch {:d}/{:d} - time: {:.2f} - training_loss: {:.4f}"
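The two hunks above attach a StepLR schedule to the existing Adam optimizer and step it once per batch, right after optimizer.step(). A minimal, self-contained sketch of that pattern (the linear model and squared-norm loss are stand-ins, not the repository's DeepGCCA classes):

    import torch

    model = torch.nn.Linear(10, 2)            # stand-in for the multi-view model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.7)

    for step in range(1000):
        x = torch.randn(32, 10)
        loss = model(x).pow(2).mean()         # dummy loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()                      # multiply lr by 0.7 every 200 steps

Because scheduler.step() sits inside the batch loop, the decay counts batches, not epochs, so the effective schedule depends on how many batches each epoch contains.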
@@ -145,7 +148,7 @@ def _get_outputs(self, x_list):
             losses = []
             outputs_list = []
             for batch_idx in batch_idxs:
-                batch_x = [x[batch_idx, :] for x in x_list]
+                batch_x = [x[batch_idx, :].to(self.device) for x in x_list]
                 outputs = self.model(batch_x)
                 outputs_list.append([o_j.clone().detach() for o_j in outputs])
                 loss = self.loss(outputs)
@@ -176,15 +179,15 @@ def save(self, name):
     outdim_size = 2
 
     # number of layers with nodes in each one
-    layer_sizes1 = [128, 128, outdim_size]
-    layer_sizes2 = [128, 128, outdim_size]
-    layer_sizes3 = [128, 128, outdim_size]
+    layer_sizes1 = [256, 512, 128, outdim_size]
+    layer_sizes2 = [256, 512, 128, outdim_size]
+    layer_sizes3 = [256, 512, 128, outdim_size]
     layer_sizes_list = [layer_sizes1, layer_sizes2, layer_sizes3]
 
     # the parameters for training the network
-    learning_rate = 1e-3
-    epoch_num = 100
-    batch_size = 40
+    learning_rate = 1e-2
+    epoch_num = 1000
+    batch_size = 400
 
     # the regularization parameter of the network
     # seems necessary to avoid the gradient exploding especially when non-saturating activations are used
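A back-of-envelope check of where the new settings leave the learning rate, assuming one scheduler step per epoch (which holds below, where batch_size equals the sample count N, giving one batch per epoch):

    # Hypothetical check, not part of the commit: final lr after 1000 steps
    # of StepLR(step_size=200, gamma=0.7) starting from 1e-2.
    lr, step_size, gamma, steps = 1e-2, 200, 0.7, 1000
    print(lr * gamma ** (steps // step_size))  # 1e-2 * 0.7**5 ≈ 1.68e-3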
@@ -201,8 +204,10 @@ def save(self, name):
 
     N = 400
     views = create_synthData(N)
-    for view in views:
-        print(view.shape)
+    print(f'input views shape :')
+    for i, view in enumerate(views):
+        print(f'view_{i} : {view.shape}')
+        view = view.to(device)
 
     # size of the input for view 1 and view 2
     input_shape_list = [view.shape[-1] for view in views]
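One caveat with the hunk above: Tensor.to() returns a new tensor, so rebinding the loop variable (view = view.to(device)) does not move the tensors stored in the views list, which is presumably why _get_outputs now also moves each batch slice with .to(self.device). A sketch of the list-level alternative, reusing the script's views and device names:

    # Hypothetical alternative, not what the commit does: rebuild the list so
    # every view actually lives on `device`, instead of re-moving batch slices.
    views = [view.to(device) for view in views]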

models.py

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@ def __init__(self, layer_sizes, input_size):
             else:
                 layers.append(nn.Sequential(
                     nn.Linear(layer_sizes[l_id], layer_sizes[l_id + 1]),
-                    nn.ReLU(),
-                    # nn.BatchNorm1d(num_features=layer_sizes[l_id + 1], affine=True),
+                    nn.Sigmoid(),
+                    nn.BatchNorm1d(num_features=layer_sizes[l_id + 1], affine=False),
                 ))
         self.layers = nn.ModuleList(layers)
 
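This change swaps each hidden block's activation from ReLU to Sigmoid and enables the previously commented-out BatchNorm1d, now with affine=False so the normalization has no learnable scale or shift. A minimal sketch of one resulting block, with made-up layer sizes:

    import torch
    import torch.nn as nn

    in_features, out_features = 256, 512      # made-up sizes for illustration
    block = nn.Sequential(
        nn.Linear(in_features, out_features),
        nn.Sigmoid(),                         # saturating activation
        nn.BatchNorm1d(num_features=out_features, affine=False),  # normalize only
    )
    x = torch.randn(8, in_features)           # batch of 8 samples
    print(block(x).shape)                     # torch.Size([8, 512])

The inline comment in main.py about gradients exploding with non-saturating activations points at the same concern: a bounded activation plus per-feature normalization keeps intermediate scales in check.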
