Commit 4c7eec2

add yet another function for automatically returning a finetune optimizer
1 parent b6e261d commit 4c7eec2

File tree: 4 files changed, +54 -14 lines

perfusion_pytorch/__init__.py

Lines changed: 6 additions & 2 deletions
@@ -14,6 +14,10 @@
 
 from perfusion_pytorch.save_load import (
     save,
-    load,
-    get_finetune_parameters
+    load
+)
+
+from perfusion_pytorch.optimizer import (
+    get_finetune_parameters,
+    get_finetune_optimizer
 )
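With this change, save and load still come from perfusion_pytorch.save_load, while the finetuning helpers now live in perfusion_pytorch.optimizer. A minimal sketch of the resulting import surface (assuming the package is installed as usual and nothing else in __init__.py shadows these names):

# after this commit, both import paths below are expected to work
from perfusion_pytorch import save, load, get_finetune_parameters, get_finetune_optimizer
from perfusion_pytorch.optimizer import get_finetune_optimizer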

perfusion_pytorch/optimizer.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+from torch.nn import Module
+from torch.optim import AdamW, Adam, Optimizer
+
+from beartype import beartype
+
+from perfusion_pytorch.embedding import EmbeddingWrapper
+from perfusion_pytorch.perfusion import Rank1EditModule
+
+# helper functions
+
+def exists(val):
+    return val is not None
+
+# function that automatically finds all the parameters necessary for fine tuning
+
+@beartype
+def get_finetune_parameters(text_image_model: Module):
+    params = []
+    for module in text_image_model.modules():
+        if isinstance(module, (EmbeddingWrapper, Rank1EditModule)):
+            params.extend(module.parameters())
+
+    return params
+
+@beartype
+def get_finetune_optimizer(
+    text_image_model: Module,
+    lr = 1e-4,
+    wd = 1e-2,
+    betas = (0.9, 0.99),
+    eps = 1e-8,
+    **kwargs
+) -> Optimizer:
+    params = get_finetune_parameters(text_image_model)
+
+    assert len(params) > 0, 'no finetuneable parameters found'
+    total_params = sum([p.numel() for p in params])
+    print(f'optimizer {total_params} parameters')
+
+    has_weight_decay = wd > 0
+    adam_klass = AdamW if has_weight_decay else Adam
+    adam_kwargs = dict(lr = lr, betas = betas, eps = eps)
+
+    if has_weight_decay:
+        adam_kwargs.update(weight_decay = wd)
+
+    return adam_klass(params, **adam_kwargs, **kwargs)
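For a sense of what get_finetune_optimizer does, here is a self-contained sketch of the same selection pattern using only stock PyTorch; the toy model and the nn.Embedding filter are stand-ins for illustration only (in perfusion-pytorch the filtered types are EmbeddingWrapper and Rank1EditModule):

import torch
from torch import nn
from torch.optim import AdamW

# toy stand-in for a text-image model; only the embedding is meant to be finetuned here
model = nn.Sequential(
    nn.Embedding(100, 16),
    nn.Linear(16, 16),
    nn.Linear(16, 1)
)

# collect parameters only from the module types targeted for finetuning,
# mirroring get_finetune_parameters
finetune_params = [
    p
    for m in model.modules()
    if isinstance(m, nn.Embedding)
    for p in m.parameters()
]

print(f'optimizing {sum(p.numel() for p in finetune_params)} parameters')

# weight decay > 0 selects AdamW, mirroring the branch in get_finetune_optimizer
opt = AdamW(finetune_params, lr = 1e-4, weight_decay = 1e-2, betas = (0.9, 0.99), eps = 1e-8)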

perfusion_pytorch/save_load.py

Lines changed: 0 additions & 11 deletions
@@ -14,17 +14,6 @@
 def exists(val):
     return val is not None
 
-# function that automatically finds all the parameters necessary for fine tuning
-
-@beartype
-def get_finetune_parameters(text_image_model: Module):
-    params = []
-    for module in text_image_model.modules():
-        if isinstance(module, (EmbeddingWrapper, Rank1EditModule)):
-            params.extend(module.parameters())
-
-    return params
-
 # saving and loading the necessary extra finetuned params
 
 @beartype

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.17',
+  version = '0.1.18',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
