Skip to content

Commit c77ce18

Browse files
committed
minor update
1 parent c9a1f6d commit c77ce18

File tree

2 files changed

+48
-61
lines changed

2 files changed

+48
-61
lines changed

models/DenseNet.py

+29-29
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ class Bottleneck(nn.Module):
99
def __init__(self, in_planes, growth_rate):
1010
super(Bottleneck, self).__init__()
1111
self.bn1 = nn.BatchNorm2d(in_planes)
12-
self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
13-
self.bn2 = nn.BatchNorm2d(4*growth_rate)
14-
self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
12+
self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False)
13+
self.bn2 = nn.BatchNorm2d(4 * growth_rate)
14+
self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
1515

1616
def forward(self, x):
17-
out = self.conv1(func.relu(self.bn1(x)))
18-
out = self.conv2(func.relu(self.bn2(out)))
19-
out = torch.cat([out,x], 1)
20-
return out
17+
x = self.conv1(func.relu(self.bn1(x)))
18+
x = self.conv2(func.relu(self.bn2(x)))
19+
x = torch.cat([x, x], 1)
20+
return x
2121

2222

2323
class Transition(nn.Module):
@@ -27,46 +27,46 @@ def __init__(self, in_planes, out_planes):
2727
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
2828

2929
def forward(self, x):
30-
out = self.conv(func.relu(self.bn(x)))
31-
out = func.avg_pool2d(out, 2)
32-
return out
30+
x = self.conv(func.relu(self.bn(x)))
31+
x = func.avg_pool2d(x, 2)
32+
return x
3333

3434

3535
class DenseNet(nn.Module):
36-
def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
36+
def __init__(self, block, num_block, growth_rate=12, reduction=0.5, num_classes=10):
3737
super(DenseNet, self).__init__()
3838
self.growth_rate = growth_rate
3939

4040
num_planes = 2*growth_rate
4141
self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
4242

43-
self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
44-
num_planes += nblocks[0]*growth_rate
45-
out_planes = int(math.floor(num_planes*reduction))
43+
self.dense1 = self._make_dense_layers(block, num_planes, num_block[0])
44+
num_planes += num_block[0] * growth_rate
45+
out_planes = int(math.floor(num_planes * reduction))
4646
self.trans1 = Transition(num_planes, out_planes)
4747
num_planes = out_planes
4848

49-
self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
50-
num_planes += nblocks[1]*growth_rate
51-
out_planes = int(math.floor(num_planes*reduction))
49+
self.dense2 = self._make_dense_layers(block, num_planes, num_block[1])
50+
num_planes += num_block[1] * growth_rate
51+
out_planes = int(math.floor(num_planes * reduction))
5252
self.trans2 = Transition(num_planes, out_planes)
5353
num_planes = out_planes
5454

55-
self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
56-
num_planes += nblocks[2]*growth_rate
57-
out_planes = int(math.floor(num_planes*reduction))
55+
self.dense3 = self._make_dense_layers(block, num_planes, num_block[2])
56+
num_planes += num_block[2] * growth_rate
57+
out_planes = int(math.floor(num_planes * reduction))
5858
self.trans3 = Transition(num_planes, out_planes)
5959
num_planes = out_planes
6060

61-
self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
62-
num_planes += nblocks[3]*growth_rate
61+
self.dense4 = self._make_dense_layers(block, num_planes, num_block[3])
62+
num_planes += num_block[3] * growth_rate
6363

6464
self.bn = nn.BatchNorm2d(num_planes)
6565
self.linear = nn.Linear(num_planes, num_classes)
6666

67-
def _make_dense_layers(self, block, in_planes, nblock):
67+
def _make_dense_layers(self, block, in_planes, num_block):
6868
layers = []
69-
for i in range(nblock):
69+
for i in range(num_block):
7070
layers.append(block(in_planes, self.growth_rate))
7171
in_planes += self.growth_rate
7272
return nn.Sequential(*layers)
@@ -84,20 +84,20 @@ def forward(self, x):
8484

8585

8686
def DenseNet121():
87-
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)
87+
return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32)
8888

8989

9090
def DenseNet169():
91-
return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
91+
return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32)
9292

9393

9494
def DenseNet201():
95-
return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
95+
return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32)
9696

9797

9898
def DenseNet161():
99-
return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
99+
return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48)
100100

101101

102102
def densenet_cifar():
103-
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
103+
return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=12)

models/ResNet.py

+19-32
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,6 @@
22
import math
33

44

5-
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6-
'resnet152']
7-
8-
9-
model_urls = {
10-
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11-
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12-
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13-
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14-
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15-
}
16-
17-
185
def conv3x3(in_planes, out_planes, stride=1):
196
# 3x3 convolution with padding
207
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
@@ -36,20 +23,20 @@ def __init__(self, inplanes, planes, stride=1, downsample=None):
3623
def forward(self, x):
3724
residual = x
3825

39-
out = self.conv1(x)
40-
out = self.bn1(out)
41-
out = self.relu(out)
26+
x = self.conv1(x)
27+
x = self.bn1(x)
28+
x = self.relu(x)
4229

43-
out = self.conv2(out)
44-
out = self.bn2(out)
30+
x = self.conv2(x)
31+
x = self.bn2(x)
4532

4633
if self.downsample is not None:
4734
residual = self.downsample(x)
4835

49-
out += residual
50-
out = self.relu(out)
36+
x += residual
37+
x = self.relu(x)
5138

52-
return out
39+
return x
5340

5441

5542
class Bottleneck(nn.Module):
@@ -70,24 +57,24 @@ def __init__(self, inplanes, planes, stride=1, downsample=None):
7057
def forward(self, x):
7158
residual = x
7259

73-
out = self.conv1(x)
74-
out = self.bn1(out)
75-
out = self.relu(out)
60+
x = self.conv1(x)
61+
x = self.bn1(x)
62+
x = self.relu(x)
7663

77-
out = self.conv2(out)
78-
out = self.bn2(out)
79-
out = self.relu(out)
64+
x = self.conv2(x)
65+
x = self.bn2(x)
66+
x = self.relu(x)
8067

81-
out = self.conv3(out)
82-
out = self.bn3(out)
68+
x = self.conv3(x)
69+
x = self.bn3(x)
8370

8471
if self.downsample is not None:
8572
residual = self.downsample(x)
8673

87-
out += residual
88-
out = self.relu(out)
74+
x += residual
75+
x = self.relu(x)
8976

90-
return out
77+
return x
9178

9279

9380
class ResNet(nn.Module):

0 commit comments

Comments
 (0)