Skip to content

Commit b3792cf

Browse files
committed
typing at field_validation
1 parent aa1dfea commit b3792cf

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

src/model_constructor/model_constructor.py

+13-15
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
}
2626

2727

28-
nnModule = Union[type[nn.Module], Callable[[], nn.Module], str]
28+
nnModule = Union[type[nn.Module], Callable[[], nn.Module]]
2929

3030

3131
class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
@@ -34,50 +34,48 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
3434
name: Optional[str] = None
3535
in_chans: int = 3
3636
num_classes: int = 1000
37-
block: nnModule = BasicBlock
38-
conv_layer: nnModule = ConvBnAct
37+
block: Union[nnModule, str] = BasicBlock
38+
conv_layer: Union[nnModule, str] = ConvBnAct
3939
block_sizes: list[int] = [64, 128, 256, 512]
4040
layers: list[int] = [2, 2, 2, 2]
41-
norm: nnModule = nn.BatchNorm2d
42-
act_fn: nnModule = nn.ReLU
43-
pool: Optional[nnModule] = None
41+
norm: Union[nnModule, str] = nn.BatchNorm2d
42+
act_fn: Union[nnModule, str] = nn.ReLU
43+
pool: Union[nnModule, str, None] = None
4444
expansion: int = 1
4545
groups: int = 1
4646
dw: bool = False
4747
div_groups: Optional[int] = None
48-
sa: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False
49-
se: Union[bool, type[nn.Module], Callable[[], nn.Module]] = False
48+
sa: Union[bool, nnModule, str] = False
49+
se: Union[bool, nnModule, str] = False
5050
se_module: Optional[bool] = None
5151
se_reduction: Optional[int] = None
5252
bn_1st: bool = True
5353
zero_bn: bool = True
5454
stem_stride_on: int = 0
5555
stem_sizes: list[int] = [64]
56-
stem_pool: Optional[nnModule] = partial(
56+
stem_pool: Union[nnModule, str, None] = partial(
5757
nn.MaxPool2d, kernel_size=3, stride=2, padding=1
5858
)
5959
stem_bn_end: bool = False
6060

6161
@field_validator("act_fn", "block", "conv_layer", "norm", "pool", "stem_pool")
6262
def set_modules( # pylint: disable=no-self-argument
63-
cls, value: Union[type[nn.Module], str], info: FieldValidationInfo,
63+
cls, value: Union[nnModule, str],
6464
) -> nnModule:
6565
"""Check values, if string, convert to nn.Module."""
6666
if is_module(value):
6767
return value
68-
if isinstance(value, str):
69-
return instantiate_module(value)
70-
# raise ValueError(f"{info.field_name} must be str or nn.Module")
68+
return instantiate_module(value)
7169

7270
@field_validator("se", "sa")
7371
def set_se( # pylint: disable=no-self-argument
74-
cls, value: Union[bool, type[nn.Module]], info: FieldValidationInfo,
72+
cls, value: Union[bool, nnModule, str], info: FieldValidationInfo,
7573
) -> nnModule:
7674
if isinstance(value, (int, bool)):
7775
return DEFAULT_SE_SA[info.field_name]
7876
if is_module(value):
7977
return value
80-
# raise ValueError(f"{info.field_name} must be bool or nn.Module") # no need - check at init
78+
return instantiate_module(value)
8179

8280
@field_validator("se_module", "se_reduction") # pragma: no cover
8381
def deprecation_warning( # pylint: disable=no-self-argument

0 commit comments

Comments
 (0)