1
- from typing import Callable , Optional
1
+ from typing import Callable , List , Optional , Type
2
2
3
3
import torch
4
4
from torch import nn
@@ -26,8 +26,8 @@ def __init__(
26
26
in_channels : int ,
27
27
mid_channels : int ,
28
28
stride : int = 1 ,
29
- conv_layer : type [ConvBnAct ] = ConvBnAct ,
30
- act_fn : type [nn .Module ] = nn .ReLU ,
29
+ conv_layer : Type [ConvBnAct ] = ConvBnAct ,
30
+ act_fn : Type [nn .Module ] = nn .ReLU ,
31
31
zero_bn : bool = True ,
32
32
bn_1st : bool = True ,
33
33
groups : int = 1 ,
@@ -150,16 +150,16 @@ def __init__(
150
150
in_channels : int ,
151
151
mid_channels : int ,
152
152
stride : int = 1 ,
153
- conv_layer : type [ConvBnAct ] = ConvBnAct ,
154
- act_fn : type [nn .Module ] = nn .ReLU ,
153
+ conv_layer : Type [ConvBnAct ] = ConvBnAct ,
154
+ act_fn : Type [nn .Module ] = nn .ReLU ,
155
155
zero_bn : bool = True ,
156
156
bn_1st : bool = True ,
157
157
groups : int = 1 ,
158
158
dw : bool = False ,
159
159
div_groups : Optional [int ] = None ,
160
160
pool : Optional [Callable [[], nn .Module ]] = None ,
161
- se : Optional [type [nn .Module ]] = None ,
162
- sa : Optional [type [nn .Module ]] = None ,
161
+ se : Optional [Type [nn .Module ]] = None ,
162
+ sa : Optional [Type [nn .Module ]] = None ,
163
163
):
164
164
super ().__init__ ()
165
165
# pool defined at ModelConstructor.
@@ -265,7 +265,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
265
265
def make_stem (cfg : ModelCfg ) -> nn .Sequential :
266
266
"""Create xResnet stem -> 3 conv 3*3 instead of 1 conv 7*7"""
267
267
len_stem = len (cfg .stem_sizes )
268
- stem : list [ tuple [ str , nn . Module ]] = [
268
+ stem : ListStrMod = [
269
269
(
270
270
f"conv_{ i } " ,
271
271
cfg .conv_layer (
@@ -341,11 +341,11 @@ class XResNet(ModelConstructor):
341
341
make_layer : Callable [[ModelCfg , int ], ModSeq ] = make_layer
342
342
make_body : Callable [[ModelCfg ], ModSeq ] = make_body
343
343
make_head : Callable [[ModelCfg ], ModSeq ] = make_head
344
- block : type [nn .Module ] = XResBlock
344
+ block : Type [nn .Module ] = XResBlock
345
345
346
346
347
347
class XResNet34 (XResNet ):
348
- layers : list [int ] = [3 , 4 , 6 , 3 ]
348
+ layers : List [int ] = [3 , 4 , 6 , 3 ]
349
349
350
350
351
351
class XResNet50 (XResNet34 ):
@@ -357,13 +357,13 @@ class YaResNet(XResNet):
357
357
YaResBlock, Mish activation, custom stem.
358
358
"""
359
359
360
- block : type [nn .Module ] = YaResBlock
361
- stem_sizes : list [int ] = [3 , 32 , 64 , 64 ]
362
- act_fn : type [nn .Module ] = nn .Mish
360
+ block : Type [nn .Module ] = YaResBlock
361
+ stem_sizes : List [int ] = [3 , 32 , 64 , 64 ]
362
+ act_fn : Type [nn .Module ] = nn .Mish
363
363
364
364
365
365
class YaResNet34 (YaResNet ):
366
- layers : list [int ] = [3 , 4 , 6 , 3 ]
366
+ layers : List [int ] = [3 , 4 , 6 , 3 ]
367
367
368
368
369
369
class YaResNet50 (YaResNet34 ):
0 commit comments