@@ -100,11 +100,19 @@ def __init__(self, device, nfeat, nclass, lweight=1, distance=EuclideanDistance(
100
100
super (CenterConfig , self ).__init__ ('Center Loss' , f"λ={ lweight } - { distance } " , loss_module , self .loss , distance )
101
101
102
102
def optimizer (self , model , task , lr ):
103
- # TODO change optimizer according to task
104
- # Was using lr0=0.001 and lr1=0.5
105
- optimizers = [optim .SGD (model .parameters (), lr = lr , momentum = 0.9 , weight_decay = 0.0005 ),
106
- optim .SGD (self .loss .center_parameters (), lr = lr )]
107
- schedulers = [lr_scheduler .StepLR (optimizers [0 ], 20 , gamma = 0.8 )]
103
+ if task == 'mnist' :
104
+ # Was using lr0=0.001 and lr1=0.5
105
+ optimizers = [optim .SGD (model .parameters (), lr = lr , momentum = 0.9 , weight_decay = 0.0005 ),
106
+ optim .SGD (self .loss .center_parameters (), lr = lr )]
107
+ schedulers = [lr_scheduler .StepLR (optimizers [0 ], 20 , gamma = 0.8 )]
108
+ elif task == 'speaker' :
109
+ optimizers = sincnet_optims (model , lr )
110
+ optimizers .append (optim .SGD (self .loss .center_parameters (), lr = lr ))
111
+ schedulers = []
112
+ elif task == 'sts' :
113
+ optimizers , schedulers = [optim .RMSprop (model .parameters (), lr = lr )], []
114
+ else :
115
+ raise ValueError ('Task must be one of mnist/speaker/sts' )
108
116
return base .Optimizer (optimizers , schedulers )
109
117
110
118
@@ -116,12 +124,20 @@ def __init__(self, device, nfeat, nclass, alpha=6.25):
116
124
super (CocoConfig , self ).__init__ ('CoCo Loss' , f"α={ alpha } " , loss_module , loss , CosineDistance ())
117
125
118
126
def optimizer (self , model , task , lr ):
119
- # TODO change optimizer according to task
120
- # Was using lr0=0.001 and lr1=0.01
121
- params = model .all_params ()
122
- optimizers = [optim .SGD (params [0 ], lr = lr , momentum = 0.9 , weight_decay = 0.0005 ),
123
- optim .SGD (params [1 ], lr = lr , momentum = 0.9 )]
124
- schedulers = [lr_scheduler .StepLR (optimizers [0 ], 10 , gamma = 0.5 )]
127
+ if task == 'mnist' :
128
+ # Was using lr0=0.001 and lr1=0.01
129
+ params = model .all_params ()
130
+ optimizers = [optim .SGD (params [0 ], lr = lr , momentum = 0.9 , weight_decay = 0.0005 ),
131
+ optim .SGD (params [1 ], lr = lr , momentum = 0.9 )]
132
+ schedulers = [lr_scheduler .StepLR (optimizers [0 ], 10 , gamma = 0.5 )]
133
+ elif task == 'speaker' :
134
+ optimizers = sincnet_optims (model , lr )
135
+ optimizers .append (optim .SGD (self .loss_module .parameters (), lr = lr ))
136
+ schedulers = []
137
+ elif task == 'sts' :
138
+ optimizers , schedulers = [optim .RMSprop (model .parameters (), lr = lr )], []
139
+ else :
140
+ raise ValueError ('Task must be one of mnist/speaker/sts' )
125
141
return base .Optimizer (optimizers , schedulers )
126
142
127
143
0 commit comments