Commit 847e41f

support cyclic schedule
1 parent 65b5cfd commit 847e41f

File tree

5 files changed (+121 -75 lines)


.circleci/config.yml

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+_run:
+  install_system_deps: &install_system_deps
+    name: install_system_deps
+    command: |
+      sudo apt-get update
+      sudo apt-get install -y cmake python-pip python-dev build-essential protobuf-compiler libprotoc-dev
+
+version: 2
+jobs:
+  python_lint:
+    docker:
+      - image: circleci/python:3.7
+    steps:
+      - checkout
+      - run: *install_system_deps
+      - run:
+          name: setup lint
+          command: |
+            sudo pip install black
+      - run:
+          name: run black
+          command: black -l 120 . --check --diff
+
+workflows:
+  version: 2
+  build_and_test:
+    jobs:
+      - python_lint
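
The new CI job simply runs black in check mode over the repository. As a convenience, the snippet below (not part of the commit) sketches how the same check could be reproduced locally from Python; it assumes black is installed and available on PATH.

# Hypothetical local helper, not part of this commit: reproduces the CI
# "run black" step, assuming `black` is installed and on PATH.
import subprocess
import sys


def run_black_check() -> int:
    """Run `black -l 120 . --check --diff` and return its exit code."""
    result = subprocess.run(["black", "-l", "120", ".", "--check", "--diff"])
    return result.returncode


if __name__ == "__main__":
    sys.exit(run_black_check())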

.style.yapf

Lines changed: 0 additions & 10 deletions
This file was deleted.

LICENSE

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2019 Gu Wang
+Copyright (c) 2019-2021 Gu Wang
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-## Flat and anneal lr scheduler in pytorch
+## (WarmUp) (Cyclic) Flat and Anneal LR Scheduler in PyTorch
 
 `warmup_method`:
 * `linear`

lr_scheduler.py

Lines changed: 91 additions & 63 deletions
@@ -3,33 +3,47 @@
 import torch
 from torch.optim import Optimizer
 from math import pi, cos
-
-
-def flat_and_anneal_lr_scheduler(optimizer, total_iters, \
-        warmup_iters=0, warmup_factor=0.1, warmup_method='linear', \
-        anneal_point=0.72, anneal_method='cosine', target_lr_factor=0, \
-        poly_power=1.0, step_gamma=0.1, steps=[0.5, 0.75], \
+import warnings
+
+
+def flat_and_anneal_lr_scheduler(
+    optimizer,
+    total_iters,
+    warmup_iters=0,
+    warmup_factor=0.1,
+    warmup_method="linear",
+    anneal_point=0.72,
+    anneal_method="cosine",
+    target_lr_factor=0,
+    poly_power=1.0,
+    step_gamma=0.1,
+    steps=[2 / 3.0, 8 / 9.0],
+    return_function=False,
 ):
-    """https://github.com/fastai/fastai/blob/master/fastai/callbacks/flat_cos_anneal.py
+    """Ref: https://github.com/fastai/fastai/blob/master/fastai/callbacks/flat_cos_anneal.py.
+
     warmup_initial_lr = warmup_factor * base_lr
     target_lr = base_lr * target_lr_factor
+    total_iters: cycle length; set to max_iter to get a one cycle schedule.
     """
     if warmup_method not in ("constant", "linear"):
-        raise ValueError("Only 'constant' or 'linear' warmup_method accepted,"
-                         "got {}".format(warmup_method))
+        raise ValueError("Only 'constant' or 'linear' warmup_method accepted," "got {}".format(warmup_method))
 
-    if anneal_method not in ("cosine", "linear", "poly", "exp", "step"):
-        raise ValueError("Only 'cosine', 'linear', 'poly', 'exp' or 'step' anneal_method accepted,"
-                         "got {}".format(anneal_method))
+    if anneal_method not in ("cosine", "linear", "poly", "exp", "step", "none"):
+        raise ValueError(
+            "Only 'cosine', 'linear', 'poly', 'exp', 'step' or 'none' anneal_method accepted,"
+            "got {}".format(anneal_method)
+        )
 
-    if anneal_method == 'step':
+    if anneal_method == "step":
         if any([_step < warmup_iters / total_iters or _step > 1 for _step in steps]):
-            raise ValueError("error in steps: {}. warmup_iters: {} total_iters: {}."
-                             "steps should be in ({},1)".format(steps, warmup_iters, total_iters, \
-                             warmup_iters / total_iters))
-        if steps != sorted(steps):
-            raise ValueError("steps {} is not in ascending order.")
-        print("ignore anneal_point when using step anneal_method")
+            raise ValueError(
+                "error in steps: {}. warmup_iters: {} total_iters: {}."
+                "steps should be in ({},1)".format(steps, warmup_iters, total_iters, warmup_iters / total_iters)
+            )
+        if list(steps) != sorted(steps):
+            raise ValueError("steps {} is not in ascending order.".format(steps))
+        warnings.warn("ignore anneal_point when using step anneal_method")
         anneal_start = steps[0] * total_iters
     else:
         if anneal_point > 1 or anneal_point < 0:
@@ -38,91 +52,108 @@ def flat_and_anneal_lr_scheduler(optimizer, total_iters, \
 
     def f(x):  # x is the iter in lr scheduler, return the lr_factor
         # the final lr is warmup_factor * base_lr
+        x = x % total_iters  # cyclic
         if x < warmup_iters:
-            if warmup_method == 'linear':
+            if warmup_method == "linear":
                 alpha = float(x) / warmup_iters
                 return warmup_factor * (1 - alpha) + alpha
-            elif warmup_method == 'constant':
+            elif warmup_method == "constant":
                 return warmup_factor
         elif x >= anneal_start:
-            if anneal_method == 'step':
+            if anneal_method == "step":
                 # ignore anneal_point and target_lr_factor
                 milestones = [_step * total_iters for _step in steps]
-                lr_factor = step_gamma**bisect_right(milestones, float(x))
-            elif anneal_method == 'cosine':
+                lr_factor = step_gamma ** bisect_right(milestones, float(x))
+            elif anneal_method == "cosine":
                 # slow --> fast --> slow
-                lr_factor = target_lr_factor + 0.5 * (1 - target_lr_factor) * \
-                    (1 + cos(pi * ((float(x) - anneal_start) / (total_iters - anneal_start))))
-            elif anneal_method == 'linear':
+                lr_factor = target_lr_factor + 0.5 * (1 - target_lr_factor) * (
+                    1 + cos(pi * ((float(x) - anneal_start) / (total_iters - anneal_start)))
+                )
+            elif anneal_method == "linear":
                 # (y-m) / (B-x) = (1-m) / (B-A)
-                lr_factor = target_lr_factor + (1 - target_lr_factor) * \
-                    (total_iters - float(x)) / (total_iters - anneal_start)
-            elif anneal_method == 'poly':
+                lr_factor = target_lr_factor + (1 - target_lr_factor) * (total_iters - float(x)) / (
+                    total_iters - anneal_start
+                )
+            elif anneal_method == "poly":
                 # slow --> fast if poly_power < 1
                 # fast --> slow if poly_power > 1
                 # when poly_power == 1.0, it is the same with linear
-                lr_factor = target_lr_factor + (1 - target_lr_factor) * \
-                    ((total_iters - float(x)) / (total_iters - anneal_start)) ** poly_power
-            elif anneal_method == 'exp':
+                lr_factor = (
+                    target_lr_factor
+                    + (1 - target_lr_factor) * ((total_iters - float(x)) / (total_iters - anneal_start)) ** poly_power
+                )
+            elif anneal_method == "exp":
                 # fast --> slow
                 # do not decay too much, especially if lr_end == 0, lr will be
                 # 0 at anneal iter, so we should avoid that
                 _target_lr_factor = max(target_lr_factor, 5e-3)
-                lr_factor = _target_lr_factor ** ( \
-                    (float(x) - anneal_start) / (total_iters - anneal_start))
+                lr_factor = _target_lr_factor ** ((float(x) - anneal_start) / (total_iters - anneal_start))
             else:
                 lr_factor = 1
             return lr_factor
         else:  # warmup_iter <= x < anneal_start_iter
             return 1
 
-    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
+    if return_function:
+        return torch.optim.lr_scheduler.LambdaLR(optimizer, f), f
+    else:
+        return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
 
 
 def test_flat_and_anneal():
     from mmcv import Config
     import numpy as np
-    model = resnet18()
 
-    optimizer_cfg = dict(type='Adam', lr=1e-4, weight_decay=0)
+    model = resnet18()
+    base_lr = 1e-4
+    optimizer_cfg = dict(type="Adam", lr=base_lr, weight_decay=0)
     optimizer = obj_from_dict(optimizer_cfg, torch.optim, dict(params=model.parameters()))
 
     # learning policy
     total_epochs = 80
     epoch_len = 500
-    total_iters = epoch_len * total_epochs
+    total_iters = epoch_len * total_epochs // 2
     # poly, step, linear, exp, cosine
     lr_cfg = Config(
         dict(
-            anneal_method='cosine',
-            warmup_method='linear',
+            # anneal_method="cosine",
+            # anneal_method="linear",
+            # anneal_method="poly",
+            # anneal_method="exp",
+            anneal_method="step",
+            warmup_method="linear",
            step_gamma=0.1,
            warmup_factor=0.1,
            warmup_iters=800,
            poly_power=5,
-            target_lr_factor=0.,
+            target_lr_factor=0.0,
            steps=[0.5, 0.75, 0.9],
            anneal_point=0.72,
-        ))
+        )
+    )
 
     # scheduler = build_scheduler(lr_config, optimizer, epoch_length)
     scheduler = flat_and_anneal_lr_scheduler(
-        optimizer=optimizer, total_iters=total_iters, \
-        warmup_method=lr_cfg.warmup_method, warmup_factor=lr_cfg.warmup_factor, \
-        warmup_iters=lr_cfg.warmup_iters, \
-        anneal_method=lr_cfg.anneal_method, anneal_point=lr_cfg.anneal_point, \
-        target_lr_factor=lr_cfg.target_lr_factor, \
-        poly_power=lr_cfg.poly_power, \
-        step_gamma=lr_cfg.step_gamma, steps=lr_cfg.steps, \
+        optimizer=optimizer,
+        total_iters=total_iters,
+        warmup_method=lr_cfg.warmup_method,
+        warmup_factor=lr_cfg.warmup_factor,
+        warmup_iters=lr_cfg.warmup_iters,
+        anneal_method=lr_cfg.anneal_method,
+        anneal_point=lr_cfg.anneal_point,
+        target_lr_factor=lr_cfg.target_lr_factor,
+        poly_power=lr_cfg.poly_power,
+        step_gamma=lr_cfg.step_gamma,
+        steps=lr_cfg.steps,
     )
-    print('start lr: {}'.format(scheduler.get_lr()))
+    print("start lr: {}".format(scheduler.get_lr()))
     steps = []
     lrs = []
 
     epoch_lrs = []
     global_step = 0
 
-    start_epoch = 20
+    start_epoch = 0
     for epoch in range(start_epoch):
         for batch in range(epoch_len):
             scheduler.step()  # when no state_dict availble
@@ -133,41 +164,38 @@ def test_flat_and_anneal():
         # scheduler.step(epoch)
         # print(type(scheduler.get_lr()[0]))
         # import pdb;pdb.set_trace()
-        epoch_lrs.append([epoch,
-                          scheduler.get_lr()[0]])  # only get the first lr (maybe a group of lrs)
+        epoch_lrs.append([epoch, scheduler.get_lr()[0]])  # only get the first lr (maybe a group of lrs)
         for batch in range(epoch_len):
            # if global_step < lr_config['warmup_iters']:
            #     scheduler.step(global_step)
            cur_lr = scheduler.get_lr()[0]
            if global_step == 0 or (len(lrs) >= 1 and cur_lr != lrs[-1]):
-                print('epoch {}, batch: {}, global_step:{} lr: {}'.format(
-                    epoch, batch, global_step, cur_lr))
+                print("epoch {}, batch: {}, global_step:{} lr: {}".format(epoch, batch, global_step, cur_lr))
                steps.append(global_step)
                lrs.append(cur_lr)
            global_step += 1
            scheduler.step()  # usually after optimizer.step()
     # print(epoch_lrs)
     # import pdb;pdb.set_trace()
-    epoch_lrs.append([total_epochs, scheduler.get_lr()[0]])
+    # epoch_lrs.append([total_epochs, scheduler.get_lr()[0]])
 
     epoch_lrs = np.asarray(epoch_lrs, dtype=np.float32)
     for i in range(len(epoch_lrs)):
-        print('{:02d} {}'.format(int(epoch_lrs[i][0]), epoch_lrs[i][1]))
+        print("{:02d} {}".format(int(epoch_lrs[i][0]), epoch_lrs[i][1]))
 
-    plt.figure(dpi=200)
-    plt.suptitle('{}'.format(dict(lr_cfg)), size=4)
+    plt.figure(dpi=100)
+    plt.suptitle("{}".format(dict(lr_cfg)), size=4)
     plt.subplot(1, 2, 1)
-    plt.plot(steps, lrs)
+    plt.plot(steps, lrs, "-.")
     # plt.show()
     plt.subplot(1, 2, 2)
     # print(epoch_lrs.dtype)
-    plt.plot(epoch_lrs[:, 0], epoch_lrs[:, 1])
+    plt.plot(epoch_lrs[:, 0], epoch_lrs[:, 1], "-.")
     plt.show()
 
 
 if __name__ == "__main__":
     from mmcv.runner import obj_from_dict
-    import sys
     import os.path as osp
     from torchvision.models import resnet18
     import matplotlib.pyplot as plt