1
+ import numpy as np
2
+ import os
3
+
4
+ import keras .callbacks as callbacks
5
+ from keras .callbacks import Callback
6
+
7
+ class SnapshotModelCheckpoint (Callback ):
8
+ """Callback that saves the snapshot weights of the model.
9
+ Saves the model weights on certain epochs (which can be considered the
10
+ snapshot of the model at that epoch).
11
+ Should be used with the cosine annealing learning rate schedule to save
12
+ the weight just before learning rate is sharply increased.
13
+ # Arguments:
14
+ nb_epochs: total number of epochs that the model will be trained for.
15
+ nb_snapshots: number of times the weights of the model will be saved.
16
+ fn_prefix: prefix for the filename of the weights.
17
+ """
18
+
19
+ def __init__ (self , nb_epochs , nb_snapshots , fn_prefix = 'Model' ):
20
+ super (SnapshotModelCheckpoint , self ).__init__ ()
21
+
22
+ self .check = nb_epochs // nb_snapshots
23
+ self .fn_prefix = fn_prefix
24
+
25
+ def on_epoch_end (self , epoch , logs = {}):
26
+ if epoch != 0 and (epoch + 1 ) % self .check == 0 :
27
+ filepath = self .fn_prefix + "-%d.h5" % ((epoch + 1 ) // self .check )
28
+ self .model .save_weights (filepath , overwrite = True )
29
+ #print("Saved snapshot at weights/%s_%d.h5" % (self.fn_prefix, epoch))
30
+
31
+
32
+ class SnapshotCallbackBuilder :
33
+ """Callback builder for snapshot ensemble training of a model.
34
+ Creates a list of callbacks, which are provided when training a model
35
+ so as to save the model weights at certain epochs, and then sharply
36
+ increase the learning rate.
37
+ """
38
+
39
+ def __init__ (self , nb_epochs , nb_snapshots , init_lr = 0.1 ):
40
+ """
41
+ Initialize a snapshot callback builder.
42
+ # Arguments:
43
+ nb_epochs: total number of epochs that the model will be trained for.
44
+ nb_snapshots: number of times the weights of the model will be saved.
45
+ init_lr: initial learning rate
46
+ """
47
+ self .T = nb_epochs
48
+ self .M = nb_snapshots
49
+ self .alpha_zero = init_lr
50
+
51
+ def get_callbacks (self , model_prefix = 'Model' ):
52
+ """
53
+ Creates a list of callbacks that can be used during training to create a
54
+ snapshot ensemble of the model.
55
+ Args:
56
+ model_prefix: prefix for the filename of the weights.
57
+ Returns: list of 3 callbacks [ModelCheckpoint, LearningRateScheduler,
58
+ SnapshotModelCheckpoint] which can be provided to the 'fit' function
59
+ """
60
+ if not os .path .exists ('weights/' ):
61
+ os .makedirs ('weights/' )
62
+
63
+ callback_list = [callbacks .ModelCheckpoint ("weights/%s-Best.h5" % model_prefix , monitor = "val_acc" ,
64
+ save_best_only = True , save_weights_only = True ),
65
+ callbacks .LearningRateScheduler (schedule = self ._cosine_anneal_schedule ),
66
+ SnapshotModelCheckpoint (self .T , self .M , fn_prefix = 'weights/%s' % model_prefix )]
67
+
68
+ return callback_list
69
+
70
+ def _cosine_anneal_schedule (self , t ):
71
+ cos_inner = np .pi * (t % (self .T // self .M )) # t - 1 is used when t has 1-based indexing.
72
+ cos_inner /= self .T // self .M
73
+ cos_out = np .cos (cos_inner ) + 1
74
+ return float (self .alpha_zero / 2 * cos_out )
0 commit comments