Skip to content

Commit 3c44f5c

Browse files
committed
add GoogleNet
1 parent f77f6a4 commit 3c44f5c

File tree

4 files changed

+107
-8
lines changed

4 files changed

+107
-8
lines changed

main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@
6666
# Net = DenseNet121()
6767
# Net = DenseNet161()
6868
# Net = DenseNet169()
69-
Net = DenseNet201()
69+
# Net = DenseNet201()
7070

71+
Net = GoogLeNet()
7172

7273
if GPU_IN_USE:
7374
Net.cuda()

models/GoogleNet.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
class Inception(nn.Module):
6+
def __init__(self, in_planes, kernel_1_out, kernel_3_in, kernel_3_out, kernel_5_in, kernel_5_out, pool_planes):
7+
super(Inception, self).__init__()
8+
# 1x1 conv branch
9+
self.b1 = nn.Sequential(
10+
nn.Conv2d(in_planes, kernel_1_out, kernel_size=1),
11+
nn.BatchNorm2d(kernel_1_out),
12+
nn.ReLU(True),
13+
)
14+
15+
# 1x1 conv -> 3x3 conv branch
16+
self.b2 = nn.Sequential(
17+
nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
18+
nn.BatchNorm2d(kernel_3_in),
19+
nn.ReLU(True),
20+
nn.Conv2d(kernel_3_in, kernel_3_out, kernel_size=3, padding=1),
21+
nn.BatchNorm2d(kernel_3_out),
22+
nn.ReLU(True),
23+
)
24+
25+
# 1x1 conv -> 5x5 conv branch
26+
self.b3 = nn.Sequential(
27+
nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
28+
nn.BatchNorm2d(kernel_5_in),
29+
nn.ReLU(True),
30+
nn.Conv2d(kernel_5_in, kernel_5_out, kernel_size=3, padding=1),
31+
nn.BatchNorm2d(kernel_5_out),
32+
nn.ReLU(True),
33+
nn.Conv2d(kernel_5_out, kernel_5_out, kernel_size=3, padding=1),
34+
nn.BatchNorm2d(kernel_5_out),
35+
nn.ReLU(True),
36+
)
37+
38+
# 3x3 pool -> 1x1 conv branch
39+
self.b4 = nn.Sequential(
40+
nn.MaxPool2d(3, stride=1, padding=1),
41+
nn.Conv2d(in_planes, pool_planes, kernel_size=1),
42+
nn.BatchNorm2d(pool_planes),
43+
nn.ReLU(True),
44+
)
45+
46+
def forward(self, x):
47+
y1 = self.b1(x)
48+
y2 = self.b2(x)
49+
y3 = self.b3(x)
50+
y4 = self.b4(x)
51+
return torch.cat([y1,y2,y3,y4], 1)
52+
53+
54+
class GoogLeNet(nn.Module):
55+
def __init__(self):
56+
super(GoogLeNet, self).__init__()
57+
self.pre_layers = nn.Sequential(
58+
nn.Conv2d(3, 192, kernel_size=3, padding=1),
59+
nn.BatchNorm2d(192),
60+
nn.ReLU(True),
61+
)
62+
63+
self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
64+
self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
65+
66+
self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
67+
68+
self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
69+
self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
70+
self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
71+
self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
72+
self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
73+
74+
self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
75+
self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
76+
77+
self.avgpool = nn.AvgPool2d(8, stride=1)
78+
self.linear = nn.Linear(1024, 10)
79+
80+
def forward(self, x):
81+
out = self.pre_layers(x)
82+
out = self.a3(out)
83+
out = self.b3(out)
84+
out = self.max_pool(out)
85+
out = self.a4(out)
86+
out = self.b4(out)
87+
out = self.c4(out)
88+
out = self.d4(out)
89+
out = self.e4(out)
90+
out = self.max_pool(out)
91+
out = self.a5(out)
92+
out = self.b5(out)
93+
out = self.avgpool(out)
94+
out = out.view(out.size(0), -1)
95+
out = self.linear(out)
96+
return out

models/LeNet.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch.nn as nn
2-
import torch.nn.functional as F
2+
import torch.nn.functional as func
3+
34

45
class LeNet(nn.Module):
56
def __init__(self):
@@ -11,12 +12,12 @@ def __init__(self):
1112
self.fc3 = nn.Linear(84, 10)
1213

1314
def forward(self, x):
14-
x = F.relu(self.conv1(x))
15-
x = F.max_pool2d(x, 2)
16-
x = F.relu(self.conv2(x))
17-
x = F.max_pool2d(x, 2)
15+
x = func.relu(self.conv1(x))
16+
x = func.max_pool2d(x, 2)
17+
x = func.relu(self.conv2(x))
18+
x = func.max_pool2d(x, 2)
1819
x = x.view(x.size(0), -1)
19-
x = F.relu(self.fc1(x))
20-
x = F.relu(self.fc2(x))
20+
x = func.relu(self.fc1(x))
21+
x = func.relu(self.fc2(x))
2122
x = self.fc3(x)
2223
return x

models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .ResNet import *
44
from .LeNet import *
55
from .DenseNet import *
6+
from .GoogleNet import *

0 commit comments

Comments
 (0)