25
25
}
26
26
27
27
28
- nnModule = Union [type [nn .Module ], Callable [[], nn .Module ], str ]
28
+ nnModule = Union [type [nn .Module ], Callable [[], nn .Module ]]
29
29
30
30
31
31
class ModelCfg (Cfg , arbitrary_types_allowed = True , extra = "forbid" ):
@@ -34,50 +34,48 @@ class ModelCfg(Cfg, arbitrary_types_allowed=True, extra="forbid"):
34
34
name : Optional [str ] = None
35
35
in_chans : int = 3
36
36
num_classes : int = 1000
37
- block : nnModule = BasicBlock
38
- conv_layer : nnModule = ConvBnAct
37
+ block : Union [ nnModule , str ] = BasicBlock
38
+ conv_layer : Union [ nnModule , str ] = ConvBnAct
39
39
block_sizes : list [int ] = [64 , 128 , 256 , 512 ]
40
40
layers : list [int ] = [2 , 2 , 2 , 2 ]
41
- norm : nnModule = nn .BatchNorm2d
42
- act_fn : nnModule = nn .ReLU
43
- pool : Optional [nnModule ] = None
41
+ norm : Union [ nnModule , str ] = nn .BatchNorm2d
42
+ act_fn : Union [ nnModule , str ] = nn .ReLU
43
+ pool : Union [nnModule , str , None ] = None
44
44
expansion : int = 1
45
45
groups : int = 1
46
46
dw : bool = False
47
47
div_groups : Optional [int ] = None
48
- sa : Union [bool , type [ nn . Module ], Callable [[], nn . Module ] ] = False
49
- se : Union [bool , type [ nn . Module ], Callable [[], nn . Module ] ] = False
48
+ sa : Union [bool , nnModule , str ] = False
49
+ se : Union [bool , nnModule , str ] = False
50
50
se_module : Optional [bool ] = None
51
51
se_reduction : Optional [int ] = None
52
52
bn_1st : bool = True
53
53
zero_bn : bool = True
54
54
stem_stride_on : int = 0
55
55
stem_sizes : list [int ] = [64 ]
56
- stem_pool : Optional [nnModule ] = partial (
56
+ stem_pool : Union [nnModule , str , None ] = partial (
57
57
nn .MaxPool2d , kernel_size = 3 , stride = 2 , padding = 1
58
58
)
59
59
stem_bn_end : bool = False
60
60
61
61
@field_validator ("act_fn" , "block" , "conv_layer" , "norm" , "pool" , "stem_pool" )
62
62
def set_modules ( # pylint: disable=no-self-argument
63
- cls , value : Union [type [ nn . Module ] , str ], info : FieldValidationInfo ,
63
+ cls , value : Union [nnModule , str ],
64
64
) -> nnModule :
65
65
"""Check values, if string, convert to nn.Module."""
66
66
if is_module (value ):
67
67
return value
68
- if isinstance (value , str ):
69
- return instantiate_module (value )
70
- # raise ValueError(f"{info.field_name} must be str or nn.Module")
68
+ return instantiate_module (value )
71
69
72
70
@field_validator ("se" , "sa" )
73
71
def set_se ( # pylint: disable=no-self-argument
74
- cls , value : Union [bool , type [ nn . Module ] ], info : FieldValidationInfo ,
72
+ cls , value : Union [bool , nnModule , str ], info : FieldValidationInfo ,
75
73
) -> nnModule :
76
74
if isinstance (value , (int , bool )):
77
75
return DEFAULT_SE_SA [info .field_name ]
78
76
if is_module (value ):
79
77
return value
80
- # raise ValueError(f"{info.field_name} must be bool or nn.Module") # no need - check at init
78
+ return instantiate_module ( value )
81
79
82
80
@field_validator ("se_module" , "se_reduction" ) # pragma: no cover
83
81
def deprecation_warning ( # pylint: disable=no-self-argument
0 commit comments