import warnings
from bisect import bisect_right  # needed by the "step" anneal method below
from math import pi, cos

import torch
from torch.optim import Optimizer


def flat_and_anneal_lr_scheduler(
    optimizer,
    total_iters,
    warmup_iters=0,
    warmup_factor=0.1,
    warmup_method="linear",
    anneal_point=0.72,
    anneal_method="cosine",
    target_lr_factor=0,
    poly_power=1.0,
    step_gamma=0.1,
    steps=(2 / 3.0, 8 / 9.0),  # tuple default avoids the shared-mutable-default pitfall
    return_function=False,
):
    """Ref: https://github.com/fastai/fastai/blob/master/fastai/callbacks/flat_cos_anneal.py

    warmup_initial_lr = warmup_factor * base_lr
    target_lr = base_lr * target_lr_factor
    total_iters: cycle length; set it to max_iter to get a one-cycle schedule.
    """
    if warmup_method not in ("constant", "linear"):
        raise ValueError("Only 'constant' or 'linear' warmup_method accepted, got {}".format(warmup_method))

    if anneal_method not in ("cosine", "linear", "poly", "exp", "step", "none"):
        raise ValueError(
            "Only 'cosine', 'linear', 'poly', 'exp', 'step' or 'none' anneal_method accepted, "
            "got {}".format(anneal_method)
        )
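
    # note: anneal_method="none" keeps the lr flat after warmup (lr_factor stays 1)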
    if anneal_method == "step":
        if any(_step < warmup_iters / total_iters or _step > 1 for _step in steps):
            raise ValueError(
                "error in steps: {}. warmup_iters: {} total_iters: {}. "
                "steps should be in ({}, 1)".format(steps, warmup_iters, total_iters, warmup_iters / total_iters)
            )
        if list(steps) != sorted(steps):
            raise ValueError("steps {} is not in ascending order.".format(steps))
        warnings.warn("anneal_point is ignored when using the step anneal_method")
        anneal_start = steps[0] * total_iters
    else:
        if anneal_point > 1 or anneal_point < 0:
            raise ValueError("anneal_point should be in [0, 1], got {}".format(anneal_point))
        anneal_start = anneal_point * total_iters
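
    # one cycle of the schedule, expressed in iterations:
    #   [0, warmup_iters): warm up from warmup_factor * base_lr toward base_lr
    #   [warmup_iters, anneal_start): stay flat at base_lr
    #   [anneal_start, total_iters): anneal from base_lr toward target_lr_factor * base_lr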

    def f(x):  # x is the iteration count seen by the lr scheduler; returns the lr factor
        # during warmup the lr starts from warmup_factor * base_lr
        x = x % total_iters  # cyclic
        if x < warmup_iters:
            if warmup_method == "linear":
                # interpolate the factor from warmup_factor (at x=0) to 1 (at x=warmup_iters)
                alpha = float(x) / warmup_iters
                return warmup_factor * (1 - alpha) + alpha
            elif warmup_method == "constant":
                return warmup_factor
        elif x >= anneal_start:
            if anneal_method == "step":
                # ignore anneal_point and target_lr_factor
                milestones = [_step * total_iters for _step in steps]
                lr_factor = step_gamma ** bisect_right(milestones, float(x))
            elif anneal_method == "cosine":
                # slow --> fast --> slow
                lr_factor = target_lr_factor + 0.5 * (1 - target_lr_factor) * (
                    1 + cos(pi * ((float(x) - anneal_start) / (total_iters - anneal_start)))
                )
            elif anneal_method == "linear":
                # (y-m) / (B-x) = (1-m) / (B-A)
                lr_factor = target_lr_factor + (1 - target_lr_factor) * (total_iters - float(x)) / (
                    total_iters - anneal_start
                )
            elif anneal_method == "poly":
                # slow --> fast if poly_power < 1
                # fast --> slow if poly_power > 1
                # when poly_power == 1.0, it is the same as linear
                lr_factor = (
                    target_lr_factor
                    + (1 - target_lr_factor) * ((total_iters - float(x)) / (total_iters - anneal_start)) ** poly_power
                )
            elif anneal_method == "exp":
                # fast --> slow
                # do not decay too much; if target_lr_factor == 0, the lr would hit 0
                # exactly at the anneal iter, so clamp the target factor away from 0
                _target_lr_factor = max(target_lr_factor, 5e-3)
                lr_factor = _target_lr_factor ** ((float(x) - anneal_start) / (total_iters - anneal_start))
            else:
                lr_factor = 1
            return lr_factor
        else:  # warmup_iters <= x < anneal_start
            return 1

    if return_function:
        return torch.optim.lr_scheduler.LambdaLR(optimizer, f), f
    else:
        return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
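

# A minimal usage sketch; the nn.Linear model, SGD optimizer, and iteration
# counts below are illustrative assumptions, not part of the original module.
def _example_usage():
    import torch.nn as nn

    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = flat_and_anneal_lr_scheduler(
        optimizer, total_iters=1000, warmup_iters=100, anneal_point=0.72, anneal_method="cosine"
    )
    for _ in range(1000):
        optimizer.step()  # a real loop would compute a loss and call loss.backward() first
        scheduler.step()  # advance the schedule once per iteration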


def test_flat_and_anneal():
    from mmcv import Config
    import numpy as np

    model = resnet18()
    base_lr = 1e-4
    optimizer_cfg = dict(type="Adam", lr=base_lr, weight_decay=0)
    optimizer = obj_from_dict(optimizer_cfg, torch.optim, dict(params=model.parameters()))

    # learning policy
    total_epochs = 80
    epoch_len = 500
    total_iters = epoch_len * total_epochs // 2  # two cycles over the full training run
    # available anneal methods: poly, step, linear, exp, cosine
    lr_cfg = Config(
        dict(
            # anneal_method="cosine",
            # anneal_method="linear",
            # anneal_method="poly",
            # anneal_method="exp",
            anneal_method="step",
            warmup_method="linear",
            step_gamma=0.1,
            warmup_factor=0.1,
            warmup_iters=800,
            poly_power=5,
            target_lr_factor=0.0,
            steps=[0.5, 0.75, 0.9],
            anneal_point=0.72,
        )
    )
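
    # With the "step" config above, the lr is multiplied by step_gamma (0.1) at
    # 50%, 75%, and 90% of each cycle, after an 800-iteration linear warmup.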

    # scheduler = build_scheduler(lr_config, optimizer, epoch_length)
    scheduler = flat_and_anneal_lr_scheduler(
        optimizer=optimizer,
        total_iters=total_iters,
        warmup_method=lr_cfg.warmup_method,
        warmup_factor=lr_cfg.warmup_factor,
        warmup_iters=lr_cfg.warmup_iters,
        anneal_method=lr_cfg.anneal_method,
        anneal_point=lr_cfg.anneal_point,
        target_lr_factor=lr_cfg.target_lr_factor,
        poly_power=lr_cfg.poly_power,
        step_gamma=lr_cfg.step_gamma,
        steps=lr_cfg.steps,
    )
    print("start lr: {}".format(scheduler.get_lr()))
    steps = []
    lrs = []

    epoch_lrs = []
    global_step = 0

    start_epoch = 0
    for epoch in range(start_epoch):
        for batch in range(epoch_len):
            scheduler.step()  # fast-forward the schedule when no state_dict is available
            global_step += 1

    for epoch in range(start_epoch, total_epochs):
        # scheduler.step(epoch)
        # print(type(scheduler.get_lr()[0]))
        # import pdb;pdb.set_trace()
        epoch_lrs.append([epoch, scheduler.get_lr()[0]])  # only record the first lr (there may be a group of lrs)

        for batch in range(epoch_len):
            # if global_step < lr_config['warmup_iters']:
            #     scheduler.step(global_step)
            cur_lr = scheduler.get_lr()[0]
            if global_step == 0 or (len(lrs) >= 1 and cur_lr != lrs[-1]):
                print("epoch {}, batch: {}, global_step: {}, lr: {}".format(epoch, batch, global_step, cur_lr))
            steps.append(global_step)
            lrs.append(cur_lr)
            global_step += 1
            scheduler.step()  # usually called after optimizer.step()
    # print(epoch_lrs)
    # import pdb;pdb.set_trace()
    # epoch_lrs.append([total_epochs, scheduler.get_lr()[0]])

    epoch_lrs = np.asarray(epoch_lrs, dtype=np.float32)
    for i in range(len(epoch_lrs)):
        print("{:02d} {}".format(int(epoch_lrs[i][0]), epoch_lrs[i][1]))

    # plot the per-iteration lr curve (left) and the per-epoch lr curve (right)
    plt.figure(dpi=100)
    plt.suptitle("{}".format(dict(lr_cfg)), size=4)
    plt.subplot(1, 2, 1)
    plt.plot(steps, lrs, "-.")
    # plt.show()
    plt.subplot(1, 2, 2)
    # print(epoch_lrs.dtype)
    plt.plot(epoch_lrs[:, 0], epoch_lrs[:, 1], "-.")
    plt.show()


if __name__ == "__main__":
    from mmcv.runner import obj_from_dict
    import os.path as osp
    from torchvision.models import resnet18
    import matplotlib.pyplot as plt

    test_flat_and_anneal()