
Commit 1e49822

refactor(structure): add the solver along with the run script
refactor the script into a solver to keep everything more organized
1 parent fee201d commit 1e49822
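
In outline, the refactor replaces module-level script code with a Solver class driven by a small main(). A condensed sketch of the new shape of main.py (the full listing is in the diff below; bodies elided with ...):

    def main():
        args = ...            # argparse: --lr, --epoch, --trainBatchSize, --testBatchSize, --cuda
        Solver(args).run()

    class Solver(object):
        def load_data(self): ...    # CIFAR-10 train/test DataLoaders
        def load_model(self): ...   # WideResNet + Adam + MultiStepLR + CrossEntropyLoss
        def train(self): ...        # one pass over train_loader
        def test(self): ...         # evaluation under torch.no_grad()
        def save(self): ...         # torch.save(self.model, "model.pth")
        def run(self): ...          # load_data, load_model, epoch loop, save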

File tree: main.py · misc.py · models/WideResNet.py

3 files changed: +147 −172 lines changed

main.py

+138 −162
@@ -11,166 +11,142 @@
 from misc import progress_bar
 
 
-# ===========================================================
-# Global variables
-# ===========================================================
-EPOCH = 200  # number of times for each run-through
-BATCH_SIZE = 100  # number of images for each epoch
-ACCURACY = 0  # overall prediction accuracy
-GPU_IN_USE = torch.cuda.is_available()  # whether using GPU
-CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')  # the 10 classes contained in the CIFAR-10 dataset
-
-
-# ===========================================================
-# parser initialization
-# ===========================================================
-parser = argparse.ArgumentParser(description="cifar-10 with PyTorch")
-parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
-parser.add_argument('--epoch', default=EPOCH, type=int, help='number of epochs to train for')
-parser.add_argument('--trainBatchSize', default=BATCH_SIZE, type=int, help='training batch size')
-parser.add_argument('--testBatchSize', default=BATCH_SIZE, type=int, help='testing batch size')
-args = parser.parse_args()
-
-train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])  # dataset training transform
-test_transform = transforms.Compose([transforms.ToTensor()])  # dataset testing transform
-
-
-# ===========================================================
-# Prepare train dataset & test dataset
-# ===========================================================
-print("***** prepare data ******")
-train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
-train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.trainBatchSize, shuffle=True)
-
-test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
-test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=args.testBatchSize, shuffle=False)
-print("data preparation......Finished")
-
-# ===========================================================
-# Prepare model
-# ===========================================================
-if GPU_IN_USE:
-    device = torch.device('cuda')
-    cudnn.benchmark = True
-else:
-    device = torch.device('cpu')
-
-print("\n***** prepare model *****")
-# Net = LeNet().to(device)
-
-# Net = AlexNet().to(device)
-
-# Net = VGG11().to(device)
-# Net = VGG13().to(device)
-# Net = VGG16().to(device)
-# Net = VGG19().to(device)
-
-# Net = GoogLeNet().to(device)
-
-# Net = resnet18().to(device)
-# Net = resnet34().to(device)
-# Net = resnet50().to(device)
-# Net = resnet101().to(device)
-# Net = resnet152().to(device)
-
-# Net = DenseNet121().to(device)
-# Net = DenseNet161().to(device)
-# Net = DenseNet169().to(device)
-# Net = DenseNet201().to(device)
-
-Net = WideResNet(depth=28, num_classes=10).to(device)
-
-optimizer = optim.Adam(Net.parameters(), lr=args.lr)  # Adam optimization
-scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75, 150], gamma=0.5)  # lr decay
-loss_function = nn.CrossEntropyLoss()
-print("model preparation......Finished")
-
-
-# Train
-# ===========================================================
-# data: [torch.cuda.FloatTensor of size 100x3x32x32 (GPU 0)]
-# target: [torch.cuda.LongTensor of size 100 (GPU 0)]
-# output: [torch.cuda.FloatTensor of size 100x10 (GPU 0)]
-# prediction: [[torch.cuda.LongTensor of size 100 (GPU 0)],
-#              [torch.cuda.LongTensor of size 100 (GPU 0)]]
-# ===========================================================
-def train():
-    print("train:")
-    Net.train()
-    train_loss = 0
-    train_correct = 0
-    total = 0
-
-    for batch_num, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
-        output = Net(data)
-        loss = loss_function(output, target)
-        loss.backward()
-        optimizer.step()
-        train_loss += loss.item()
-        prediction = torch.max(output, 1)  # second param "1" represents the dimension to be reduced
-        total += target.size(0)
-
-        # train_correct incremented by one if predicted right
-        train_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())
-
-        progress_bar(batch_num, len(train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
-                     % (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))
-
-    return train_loss, train_correct / total
-
-
-# test
-# ===========================================================
-# data: [torch.cuda.FloatTensor of size 100x3x32x32 (GPU 0)]
-# target: [torch.cuda.LongTensor of size 100 (GPU 0)]
-# output: [torch.cuda.FloatTensor of size 100x10 (GPU 0)]
-# prediction: [[torch.cuda.LongTensor of size 100 (GPU 0)],
-#              [torch.cuda.LongTensor of size 100 (GPU 0)]]
-# ===========================================================
-def test():
-    print("test:")
-    Net.eval()
-    test_loss = 0
-    test_correct = 0
-    total = 0
-
-    with torch.no_grad():
-        for batch_num, (data, target) in enumerate(test_loader):
-            data, target = data.to(device), target.to(device)
-            output = Net(data)
-            loss = loss_function(output, target)
-            test_loss += loss.item()
-            prediction = torch.max(output, 1)
+CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
+
+
+def main():
+    parser = argparse.ArgumentParser(description="cifar-10 with PyTorch")
+    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
+    parser.add_argument('--epoch', default=200, type=int, help='number of epochs to train for')
+    parser.add_argument('--trainBatchSize', default=100, type=int, help='training batch size')
+    parser.add_argument('--testBatchSize', default=100, type=int, help='testing batch size')
+    parser.add_argument('--cuda', default=torch.cuda.is_available(), type=bool, help='whether cuda is in use')
+    args = parser.parse_args()
+
+    solver = Solver(args)
+    solver.run()
+
+
+class Solver(object):
+    def __init__(self, config):
+        self.model = None
+        self.lr = config.lr
+        self.epochs = config.epoch
+        self.train_batch_size = config.trainBatchSize
+        self.test_batch_size = config.testBatchSize
+        self.criterion = None
+        self.optimizer = None
+        self.scheduler = None
+        self.device = None
+        self.cuda = config.cuda
+        self.train_loader = None
+        self.test_loader = None
+
+    def load_data(self):
+        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
+        test_transform = transforms.Compose([transforms.ToTensor()])
+        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
+        self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)
+        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
+        self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)
+
+    def load_model(self):
+        if self.cuda:
+            self.device = torch.device('cuda')
+            cudnn.benchmark = True
+        else:
+            self.device = torch.device('cpu')
+
+        # self.model = LeNet().to(self.device)
+        # self.model = AlexNet().to(self.device)
+        # self.model = VGG11().to(self.device)
+        # self.model = VGG13().to(self.device)
+        # self.model = VGG16().to(self.device)
+        # self.model = VGG19().to(self.device)
+        # self.model = GoogLeNet().to(self.device)
+        # self.model = resnet18().to(self.device)
+        # self.model = resnet34().to(self.device)
+        # self.model = resnet50().to(self.device)
+        # self.model = resnet101().to(self.device)
+        # self.model = resnet152().to(self.device)
+        # self.model = DenseNet121().to(self.device)
+        # self.model = DenseNet161().to(self.device)
+        # self.model = DenseNet169().to(self.device)
+        # self.model = DenseNet201().to(self.device)
+        self.model = WideResNet(depth=28, num_classes=10).to(self.device)
+
+        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
+        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[75, 150], gamma=0.5)
+        self.criterion = nn.CrossEntropyLoss().to(self.device)
+
+    def train(self):
+        print("train:")
+        self.model.train()
+        train_loss = 0
+        train_correct = 0
+        total = 0
+
+        for batch_num, (data, target) in enumerate(self.train_loader):
+            data, target = data.to(self.device), target.to(self.device)
+            self.optimizer.zero_grad()
+            output = self.model(data)
+            loss = self.criterion(output, target)
+            loss.backward()
+            self.optimizer.step()
+            train_loss += loss.item()
+            prediction = torch.max(output, 1)  # second param "1" represents the dimension to be reduced
             total += target.size(0)
-            test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())
-
-            progress_bar(batch_num, len(test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
-                         % (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))
-
-    return test_loss, test_correct / total
-
-
-# ===========================================================
-# Save model
-# ===========================================================
-def save():
-    model_out_path = "model.pth"
-    torch.save(Net, model_out_path)
-    print("Checkpoint saved to {}".format(model_out_path))
-
-
-# ===========================================================
-# training and save model
-# ===========================================================
-for epoch in range(1, args.epoch + 1):
-    scheduler.step(epoch)
-    print("\n===> epoch: %d/200" % epoch)
-    train_result = train()
-    print(train_result)
-    test_result = test()
-    ACCURACY = max(ACCURACY, test_result[1])
-    if epoch == args.epoch:
-        print("===> BEST ACC. PERFORMANCE: %.3f%%" % (ACCURACY * 100))
-        save()
+
+            # train_correct incremented by one if predicted right
+            train_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())
+
+            progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
+                         % (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))
+
+        return train_loss, train_correct / total
+
+    def test(self):
+        print("test:")
+        self.model.eval()
+        test_loss = 0
+        test_correct = 0
+        total = 0
+
+        with torch.no_grad():
+            for batch_num, (data, target) in enumerate(self.test_loader):
+                data, target = data.to(self.device), target.to(self.device)
+                output = self.model(data)
+                loss = self.criterion(output, target)
+                test_loss += loss.item()
+                prediction = torch.max(output, 1)
+                total += target.size(0)
+                test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())
+
+                progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
+                             % (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))
+
+        return test_loss, test_correct / total
+
+    def save(self):
+        model_out_path = "model.pth"
+        torch.save(self.model, model_out_path)
+        print("Checkpoint saved to {}".format(model_out_path))
+
+    def run(self):
+        self.load_data()
+        self.load_model()
+        accuracy = 0
+        for epoch in range(1, self.epochs + 1):
+            self.scheduler.step(epoch)
+            print("\n===> epoch: %d/200" % epoch)
+            train_result = self.train()
+            print(train_result)
+            test_result = self.test()
+            accuracy = max(accuracy, test_result[1])
+            if epoch == self.epochs:
+                print("===> BEST ACC. PERFORMANCE: %.3f%%" % (accuracy * 100))
+                self.save()
+
+
+if __name__ == '__main__':
+    main()
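
With this refactor the trainer can also be exercised without the CLI. A minimal sketch, assuming main.py sits on the import path; SimpleNamespace merely mimics the parsed-args object, and the one-epoch settings are illustrative, not part of the commit:

    # CLI, as before:  python main.py --lr 0.001 --epoch 200
    from types import SimpleNamespace
    from main import Solver  # safe to import: execution is guarded by __name__ == '__main__'

    config = SimpleNamespace(lr=0.001, epoch=1, trainBatchSize=100,
                             testBatchSize=100, cuda=False)  # CPU smoke test
    Solver(config).run()  # load_data -> load_model -> train/test loop -> save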

misc.py

−1

@@ -45,7 +45,6 @@ def progress_bar(current, total, msg=None):
     sys.stdout.flush()
 
 
-# return the formatted time
 def format_time(seconds):
     days = int(seconds / 3600/24)
     seconds = seconds - days*3600*24

models/WideResNet.py

+9 −9

@@ -52,24 +52,24 @@ def forward(self, x):
 class WideResNet(nn.Module):
     def __init__(self, depth, num_classes, widen_factor=1, drop_rate=0.0):
         super(WideResNet, self).__init__()
-        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
+        n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
         assert ((depth - 4) % 6 == 0)
-        n = (depth - 4) / 6
+        n = int((depth - 4) / 6)
         block = BasicBlock
         # 1st conv before any network block
-        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
+        self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, stride=1,
                                padding=1, bias=False)
         # 1st block
-        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, drop_rate)
+        self.block1 = NetworkBlock(n, n_channels[0], n_channels[1], block, 1, drop_rate)
         # 2nd block
-        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, drop_rate)
+        self.block2 = NetworkBlock(n, n_channels[1], n_channels[2], block, 2, drop_rate)
         # 3rd block
-        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, drop_rate)
+        self.block3 = NetworkBlock(n, n_channels[2], n_channels[3], block, 2, drop_rate)
         # global average pooling and classifier
-        self.bn1 = nn.BatchNorm2d(nChannels[3])
+        self.bn1 = nn.BatchNorm2d(n_channels[3])
         self.relu = nn.ReLU(inplace=True)
-        self.fc = nn.Linear(nChannels[3], num_classes)
-        self.nChannels = nChannels[3]
+        self.fc = nn.Linear(n_channels[3], num_classes)
+        self.nChannels = n_channels[3]
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
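
The int() cast on n is the substantive fix in this file: under Python 3, / is true division, so (depth - 4) / 6 is a float, and a float block count fails wherever NetworkBlock repeats its BasicBlocks (a range(n)-style loop is assumed here; that code is not shown in this diff). A quick illustration:

    depth = 28
    assert (depth - 4) % 6 == 0
    n = (depth - 4) / 6       # 4.0 -- a float under Python 3 true division
    # range(n)                # TypeError: 'float' object cannot be interpreted as an integer
    n = int((depth - 4) / 6)  # 4 -- the commit's fix; n = (depth - 4) // 6 works equally well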
