Skip to content

Commit e63256a

Browse files
author
juan coria
committed
Merge branch 'develop' of https://github.com/juanmc2005/repr-learning into develop
2 parents c4aba7f + c79824f commit e63256a

File tree

2 files changed

+51
-21
lines changed

2 files changed

+51
-21
lines changed

datasets/voxceleb.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,22 @@ def __next__(self):
3737
return batch, torch.Tensor(dic['y']).long()
3838

3939

40-
class VoxCeleb1(SimDataset):
41-
sample_rate = 16000
42-
43-
@staticmethod
44-
def config(segment_size_s: float):
45-
return metrics.SpeakerValidationConfig(protocol_name='VoxCeleb.SpeakerVerification.VoxCeleb1_X',
46-
feature_extraction=RawAudio(sample_rate=VoxCeleb1.sample_rate),
47-
preprocessors={'audio': FileFinder()},
48-
duration=segment_size_s)
40+
class VoxCelebDataset(SimDataset):
4941

5042
def __init__(self, batch_size: int, segment_size_millis: int):
43+
self.sample_rate = 16000
5144
self.batch_size = batch_size
5245
self.segment_size_s = segment_size_millis / 1000
53-
self.nfeat = VoxCeleb1.sample_rate * segment_size_millis // 1000
46+
self.nfeat = self.sample_rate * segment_size_millis // 1000
5447
self.config = self._create_config(self.segment_size_s)
5548
self.protocol = get_protocol(self.config.protocol_name, preprocessors=self.config.preprocessors)
5649
self.train_gen, self.dev_gen, self.test_gen = None, None, None
5750
print(f"[Segment Size: {self.segment_size_s}s]")
5851
print(f"[Embedding Size: {self.nfeat}]")
5952

53+
def _create_config(self, segment_size_sec: float):
54+
raise NotImplementedError
55+
6056
def training_partition(self) -> VoxCelebPartition:
6157
if self.train_gen is None:
6258
self.train_gen = SpeechSegmentGenerator(self.config.feature_extraction, self.protocol,
@@ -77,3 +73,21 @@ def test_partition(self) -> VoxCelebPartition:
7773
subset='test', per_label=1, per_fold=self.batch_size,
7874
duration=self.segment_size_s, parallel=2)
7975
return VoxCelebPartition(self.test_gen, self.nfeat)
76+
77+
78+
class VoxCeleb1(VoxCelebDataset):
79+
80+
def _create_config(self, segment_size_sec: float):
81+
return metrics.SpeakerValidationConfig(protocol_name='VoxCeleb.SpeakerVerification.VoxCeleb1_X',
82+
feature_extraction=RawAudio(sample_rate=self.sample_rate),
83+
preprocessors={'audio': FileFinder()},
84+
duration=segment_size_sec)
85+
86+
87+
class VoxCeleb2(VoxCelebDataset):
88+
89+
def _create_config(self, segment_size_sec: float):
90+
return metrics.SpeakerValidationConfig(protocol_name='VoxCeleb.SpeakerVerification.VoxCeleb2',
91+
feature_extraction=RawAudio(sample_rate=self.sample_rate),
92+
preprocessors={'audio': FileFinder()},
93+
duration=segment_size_sec)

losses/config.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,19 @@ def __init__(self, device, nfeat, nclass, lweight=1, distance=EuclideanDistance(
100100
super(CenterConfig, self).__init__('Center Loss', f"λ={lweight} - {distance}", loss_module, self.loss, distance)
101101

102102
def optimizer(self, model, task, lr):
103-
# TODO change optimizer according to task
104-
# Was using lr0=0.001 and lr1=0.5
105-
optimizers = [optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005),
106-
optim.SGD(self.loss.center_parameters(), lr=lr)]
107-
schedulers = [lr_scheduler.StepLR(optimizers[0], 20, gamma=0.8)]
103+
if task == 'mnist':
104+
# Was using lr0=0.001 and lr1=0.5
105+
optimizers = [optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005),
106+
optim.SGD(self.loss.center_parameters(), lr=lr)]
107+
schedulers = [lr_scheduler.StepLR(optimizers[0], 20, gamma=0.8)]
108+
elif task == 'speaker':
109+
optimizers = sincnet_optims(model, lr)
110+
optimizers.append(optim.SGD(self.loss.center_parameters(), lr=lr))
111+
schedulers = []
112+
elif task == 'sts':
113+
optimizers, schedulers = [optim.RMSprop(model.parameters(), lr=lr)], []
114+
else:
115+
raise ValueError('Task must be one of mnist/speaker/sts')
108116
return base.Optimizer(optimizers, schedulers)
109117

110118

@@ -116,12 +124,20 @@ def __init__(self, device, nfeat, nclass, alpha=6.25):
116124
super(CocoConfig, self).__init__('CoCo Loss', f"α={alpha}", loss_module, loss, CosineDistance())
117125

118126
def optimizer(self, model, task, lr):
119-
# TODO change optimizer according to task
120-
# Was using lr0=0.001 and lr1=0.01
121-
params = model.all_params()
122-
optimizers = [optim.SGD(params[0], lr=lr, momentum=0.9, weight_decay=0.0005),
123-
optim.SGD(params[1], lr=lr, momentum=0.9)]
124-
schedulers = [lr_scheduler.StepLR(optimizers[0], 10, gamma=0.5)]
127+
if task == 'mnist':
128+
# Was using lr0=0.001 and lr1=0.01
129+
params = model.all_params()
130+
optimizers = [optim.SGD(params[0], lr=lr, momentum=0.9, weight_decay=0.0005),
131+
optim.SGD(params[1], lr=lr, momentum=0.9)]
132+
schedulers = [lr_scheduler.StepLR(optimizers[0], 10, gamma=0.5)]
133+
elif task == 'speaker':
134+
optimizers = sincnet_optims(model, lr)
135+
optimizers.append(optim.SGD(self.loss_module.parameters(), lr=lr))
136+
schedulers = []
137+
elif task == 'sts':
138+
optimizers, schedulers = [optim.RMSprop(model.parameters(), lr=lr)], []
139+
else:
140+
raise ValueError('Task must be one of mnist/speaker/sts')
125141
return base.Optimizer(optimizers, schedulers)
126142

127143

0 commit comments

Comments
 (0)