Skip to content

Commit 810dd7c

Browse files
author
juan coria
committed
Added automatic plotting of the training logs after training finishes or the user interrupts it. Added STS augmentation as a script parameter. Other improvements and bug fixes.
1 parent e63256a commit 810dd7c

11 files changed

+173
-66
lines changed

common.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import torch
66
import numpy as np
77
import random
8+
89
import losses.config as cf
10+
from losses.triplet import SemiHardNegative
911
from distances import CosineDistance, EuclideanDistance
1012

1113

@@ -86,9 +88,10 @@ def get_config(loss: str, nfeat: int, nclass: int, task: str, margin: float) ->
8688
print(f"[Margin: {margin}]")
8789
return cf.TripletConfig(DEVICE,
8890
margin=margin,
89-
distance=EuclideanDistance(),
91+
distance=CosineDistance(),
9092
size_average=False,
91-
online=task != 'sts')
93+
online=task != 'sts',
94+
sampling=SemiHardNegative(margin, deviation=0.02))
9295
elif loss == 'arcface':
9396
print(f"[Margin: {margin}]")
9497
return cf.ArcFaceConfig(DEVICE, nfeat, nclass, margin=margin)
@@ -134,3 +137,22 @@ def dump_params(filepath: str, args):
134137
with open(filepath, 'w') as out:
135138
for k, v in sorted(vars(args).items()):
136139
out.write(f"{k}={v}\n")
140+
141+
142+
def get_basic_plots(lr: float, batch_size: int, eval_metric: str, eval_metric_color: str) -> list:
    """
    Describe the two default plots of a run: train loss and dev metric.

    :param lr: learning rate, shown in both plot titles
    :param batch_size: batch size, shown in both plot titles
    :param eval_metric: display name of the dev metric; also used, lowercased
        and with spaces turned into dashes, to build the metric plot file name
    :param eval_metric_color: color entry of the dev metric plot
    :return: a list of two dicts (loss plot first, metric plot second), each
        with the keys 'log_file', 'metric', 'color', 'title' and 'filename'
    """
    # Both titles end with the same hyperparameter summary
    hyperparams = f'lr={lr} - batch_size={batch_size}'
    metric_slug = eval_metric.lower().replace(' ', '-')

    loss_plot = dict(log_file='loss.log',
                     metric='Loss',
                     color='blue',
                     title=f'Train Loss - {hyperparams}',
                     filename='loss-plot')
    metric_plot = dict(log_file='metric.log',
                       metric=eval_metric,
                       color=eval_metric_color,
                       title=f'Dev {eval_metric} - {hyperparams}',
                       filename=f"dev-{metric_slug}-plot")
    return [loss_plot, metric_plot]

core/base.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from datasets.base import SimDatasetPartition
77
from core.optim import Optimizer
88
import common
9+
import visual_utils as vis
910

1011

1112
class TrainingListener:
@@ -64,23 +65,42 @@ def __init__(self, loss_name: str, model: SimNet, loss_fn: nn.Module, partition:
6465
def _restore(self):
6566
if self.model_loader is not None:
6667
checkpoint = self.model_loader.restore(self.model, self.loss_fn, self.optim, self.loss_name)
67-
epoch = checkpoint['epoch']
68-
return checkpoint, epoch + 1
68+
return checkpoint
6969
else:
70-
return None, 1
71-
72-
def train(self, epochs):
73-
checkpoint, epoch = self._restore()
70+
return None
71+
72+
def _create_plots(self, exp_path: str, plots: list):
73+
print("Creating training plots before exiting...")
74+
for plot in plots:
75+
vis.visualize_logs(exp_path,
76+
log_file_name=plot['log_file'],
77+
metric_name=plot['metric'],
78+
color=plot['color'],
79+
title=plot['title'],
80+
plot_file_name=plot['filename'])
81+
print("Done")
82+
83+
def _start_training(self, epochs):
84+
checkpoint = self._restore()
7485

7586
for cb in self.callbacks:
7687
cb.on_before_train(checkpoint)
7788

78-
for i in range(epoch, epoch+epochs):
89+
for i in range(1, epochs + 1):
7990
self.train_epoch(i)
8091

8192
for cb in self.callbacks:
8293
cb.on_after_train()
8394

95+
def train(self, epochs: int, exp_path: str, plots: list):
    """
    Run the training loop, then create the log plots.

    Plots are created both when training completes and when the user
    interrupts it with Ctrl-C. Any other exception raised by training
    propagates to the caller and no plots are created.

    :param epochs: number of epochs, forwarded to `_start_training`
    :param exp_path: forwarded to `_create_plots`
    :param plots: plot descriptions, forwarded to `_create_plots`
    """
    try:
        self._start_training(epochs)
        print("Training finished")
    except KeyboardInterrupt:
        # A user interrupt is an accepted way to end training early
        print("Stopped by user")
    # Plot once, outside the try block: an interrupt raised while plotting
    # must not trigger a second plotting attempt from the handler
    self._create_plots(exp_path, plots)
103+
84104
def train_epoch(self, epoch):
85105
self.model.train()
86106

datasets/semeval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88

99
from datasets.base import SimDataset, SimDatasetPartition
10+
from models import SimNet
1011
from sts.augmentation import SemEvalAugmentationStrategy, pad_sent_pair
1112
from sts import utils as sts
1213

datasets/voxceleb.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, batch_size: int, segment_size_millis: int):
4444
self.batch_size = batch_size
4545
self.segment_size_s = segment_size_millis / 1000
4646
self.nfeat = self.sample_rate * segment_size_millis // 1000
47-
self.config = VoxCeleb1.config(self.segment_size_s)
47+
self.config = self._create_config(self.segment_size_s)
4848
self.protocol = get_protocol(self.config.protocol_name, preprocessors=self.config.preprocessors)
4949
self.train_gen, self.dev_gen, self.test_gen = None, None, None
5050
print(f"[Segment Size: {self.segment_size_s}s]")
@@ -84,7 +84,7 @@ def _create_config(self, segment_size_sec: float):
8484
duration=segment_size_sec)
8585

8686

87-
class VoxCeleb2(VoxCelebDataset)
87+
class VoxCeleb2(VoxCelebDataset):
8888

8989
def _create_config(self, segment_size_sec: float):
9090
return metrics.SpeakerValidationConfig(protocol_name='VoxCeleb.SpeakerVerification.VoxCeleb2',

distances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def to_sklearn_metric(self):
9292
return 'euclidean'
9393

9494
def dist(self, x, y):
95-
return torch.dist(x, y, p=2)
95+
return torch.sum(torch.pow((x - y), 2), dim=1)
9696

9797
def sqdist_sum(self, x, y):
9898
return (x - y).pow(2).sum()

losses/triplet.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3+
import torch
34
import torch.nn as nn
45
import torch.nn.functional as F
56
import numpy as np
@@ -19,6 +20,9 @@ def triplets(self, y, distances):
1920
"""
2021
raise NotImplementedError("a TripletSamplingStrategy should implement 'triplets'")
2122

23+
def filter(self, dist_pos, dist_neg):
    """
    Default filtering hook: hand both distance values back untouched.

    NOTE(review): SemiHardNegative.filter in this file returns a single
    keep-mask instead of a pair - confirm which contract callers expect.

    :param dist_pos: distances between anchors and positives
    :param dist_neg: distances between anchors and negatives
    :return: the tuple (dist_pos, dist_neg), unchanged
    """
    unfiltered = (dist_pos, dist_neg)
    return unfiltered
25+
2226

2327
class BatchAll(TripletSamplingStrategy):
2428
"""
@@ -41,6 +45,20 @@ def triplets(self, y, distances):
4145
return anchors, positives, negatives
4246

4347

48+
class SemiHardNegative(TripletSamplingStrategy):
    """
    Strategy whose `filter` keeps the triplets with an anchor-negative
    distance that is within the margin, or exceeds it by at most a tolerance.

    NOTE(review): `triplets` is not overridden here, so the inherited
    implementation is used - confirm that this is intended.

    :param m: the margin that negative distances are compared against
    :param deviation: how far above the margin a negative distance may be
        while still being kept
    """

    def __init__(self, m: float, deviation: float):
        self.m = m
        self.deviation = deviation

    def filter(self, dist_pos, dist_neg):
        """
        Build a keep-mask over the triplets from their negative distances.

        :param dist_pos: anchor-positive distances (not used by this strategy)
        :param dist_neg: tensor of anchor-negative distances
        :return: a float tensor shaped like `dist_neg`, on the same device,
            holding 1. for entries to keep and 0. for entries to drop
        """
        # Same two conditions as an element-by-element check, evaluated on
        # the whole tensor at once so no Python loop or host copy is needed
        within_margin = dist_neg <= self.m
        within_tolerance = (dist_neg - self.m) <= self.deviation
        return (within_margin | within_tolerance).float()
60+
61+
4462
class HardestNegative(TripletSamplingStrategy):
4563
"""
4664
Hardest negative strategy.
@@ -100,7 +118,7 @@ class TripletLoss(nn.Module):
100118
:param device: a device in which to run the computation
101119
:param margin: a margin value to separate classes
102120
:param distance: a distance object to measure between the samples
103-
:param strategy: a TripletSamplingStrategy
121+
:param sampling: a TripletSamplingStrategy
104122
"""
105123

106124
def __init__(self, device: str, margin: float, distance: Distance,
@@ -153,7 +171,10 @@ def forward(self, feat, logits, y):
153171
# Calculate the distances to positives and negatives for each anchor
154172
dpos = self.distance.dist(anchors, positives)
155173
dneg = self.distance.dist(anchors, negatives)
174+
# keep_mask = self.sampling.filter(dpos, dneg).to(self.device)
175+
# dpos = keep_mask * dpos
176+
# dneg = keep_mask * dneg
156177

157178
# Calculate the loss using the margin
158-
loss = F.relu(dpos - dneg + self.margin)
179+
loss = F.relu(torch.pow(dpos, 2) - torch.pow(dneg, 2) + self.margin)
159180
return loss.mean() if self.size_average else loss.sum()

metrics.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,6 @@ def eval(self, model, partition: str = 'development'):
195195
# Returning 1-eer because the evaluator keeps track of the highest metric value
196196
return 1 - eer, y_pred, y_true
197197

198-
def on_before_train(self, checkpoint):
199-
if checkpoint is not None:
200-
self.best_metric = checkpoint['accuracy']
201-
202198
def on_after_epoch(self, epoch, model, loss_fn, optim):
203199
if epoch % self.eval_interval == 0:
204200
metric_value, dists, y_true = self.eval(model.to_prediction_model(), self.partition)
@@ -259,10 +255,6 @@ def _eval(self, model):
259255
feat_test, y_test = np.concatenate(feat_test), np.concatenate(y_test)
260256
return feat_test, y_test
261257

262-
def on_before_train(self, checkpoint):
263-
if checkpoint is not None:
264-
self.best_metric = checkpoint['accuracy']
265-
266258
def on_before_epoch(self, epoch):
267259
self.feat_train, self.y_train = [], []
268260

@@ -339,10 +331,6 @@ def eval(self, model):
339331
y_test = np.concatenate(y_test)
340332
return phrases, feat_test, y_test
341333

342-
def on_before_train(self, checkpoint):
343-
if checkpoint is not None:
344-
self.best_metric = checkpoint['accuracy']
345-
346334
def on_after_epoch(self, epoch, model, loss_fn, optim):
347335
_, feat_test, y_test = self.eval(model.to_prediction_model())
348336
metric_value = self.metric.get()
@@ -409,10 +397,6 @@ def eval(self, model):
409397
feat_test, y_test = np.concatenate(feat_test), np.concatenate(y_test)
410398
return phrases, feat_test, y_test
411399

412-
def on_before_train(self, checkpoint):
413-
if checkpoint is not None:
414-
self.best_metric = checkpoint['accuracy']
415-
416400
def on_after_epoch(self, epoch, model, loss_fn, optim):
417401
phrases, feat_test, y_test = self.eval(model)
418402
metric_value = self.metric.get()

0 commit comments

Comments
 (0)