Skip to content

Commit 661147e

Browse files
author
laggui
committed
Added python implementation
1 parent caebda2 commit 661147e

File tree

5 files changed

+405
-0
lines changed

5 files changed

+405
-0
lines changed

pytorch/predict.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.backends.cudnn as cudnn
4+
5+
import torchvision.transforms as transforms
6+
7+
from torch import jit
8+
from PIL import Image
9+
10+
import io
11+
import time
12+
import argparse
13+
import cv2
14+
15+
from vgg import VGGNet
16+
17+
# Check device
18+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19+
# CIFAR-10 classes
20+
classes = ('plane', 'car', 'bird', 'cat',
21+
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
22+
23+
def predict(model, image):
24+
# apply transform and convert BGR -> RGB
25+
x = image[:, :, (2, 1, 0)]
26+
#print('Image shape: {}'.format(x.shape))
27+
# H x W x C -> C x H x W for conv input
28+
x = torch.from_numpy(x).permute(2, 0, 1)
29+
torch.set_printoptions(threshold=5000)
30+
31+
to_norm_tensor = transforms.Compose([
32+
#transforms.ToTensor(),
33+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
34+
])
35+
36+
img_tensor = to_norm_tensor(x.float().div_(255))
37+
#print('Image tensor: {}'.format(img_tensor))
38+
#print('Image tensor shape: {}'.format(img_tensor.shape))
39+
img_tensor.unsqueeze_(0).to(device) # add a dimension for the batch
40+
#print('New shape: {}'.format(img_tensor.shape))
41+
42+
with torch.no_grad():
43+
# forward pass
44+
outputs = model(img_tensor)
45+
score, predicted = outputs.max(1)
46+
#print(outputs)
47+
print('Predicted: {} | {}'.format(classes[predicted.item()], score.item()))
48+
49+
50+
51+
if __name__ == '__main__':
52+
parser = argparse.ArgumentParser(description='VGGNet Predict Tool')
53+
parser.add_argument('mtype', type=str, choices=['pytorch', 'torch-script'], help='Model type')
54+
parser.add_argument('--model', type=str, default='../data/VGG16model.pth', help='Pre-trained model')
55+
parser.add_argument('--image', type=str, default='../data/dog.png', help='Input image')
56+
args = parser.parse_args()
57+
58+
# Model
59+
print('==> Building model...')
60+
if args.mtype == 'pytorch':
61+
model = VGGNet('D-DSM', num_classes=10, input_size=32) # depthwise separable
62+
# Load model
63+
print('==> Loading PyTorch model...')
64+
model.load_state_dict(torch.load(args.model))
65+
model.eval()
66+
model.to(device)
67+
else:
68+
print('==> Loading Torch Script model...')
69+
# Load ScriptModule from io.BytesIO object
70+
with open(args.model, 'rb') as f:
71+
buffer = io.BytesIO(f.read())
72+
model = torch.jit.load(buffer)
73+
print('[WARNING] ScriptModules cannot be moved to a GPU device yet. Running strictly on CPU for now.')
74+
device = torch.device('cpu') # 'to' is not supported on TracedModules (yet)
75+
76+
if device.type == 'cuda':
77+
cudnn.benchmark = True
78+
model = torch.nn.DataParallel(model)
79+
80+
t0 = time.time()
81+
predict(model, cv2.imread(args.image))
82+
print('Time: {} seconds'.format(time.time()-t0))

pytorch/test.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.backends.cudnn as cudnn
4+
5+
import torchvision.datasets as datasets
6+
import torchvision.transforms as transforms
7+
8+
from torch.utils.data import DataLoader
9+
from torch import jit
10+
11+
import io
12+
import time
13+
import argparse
14+
15+
from vgg import VGGNet
16+
17+
# Check device
18+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19+
#device = torch.device('cpu') # 'to' is not supported on TracedModules, ref: https://github.com/pytorch/pytorch/issues/6008
20+
21+
def test(model, test_loader):
22+
#model.eval()
23+
print_freq = 10 # print every 10 batches
24+
correct = 0
25+
total = 0
26+
27+
with torch.no_grad(): # no need to track history
28+
for batch_idx, (inputs, targets) in enumerate(test_loader):
29+
inputs, targets = inputs.to(device), targets.to(device)
30+
31+
# compute output
32+
outputs = model(inputs)
33+
34+
# record prediction accuracy
35+
_, predicted = outputs.max(1)
36+
total += targets.size(0)
37+
correct += predicted.eq(targets).sum().item()
38+
39+
if batch_idx % print_freq == 0:
40+
print('Batch: %d, Acc: %.3f%% (%d/%d)' % (batch_idx+1, 100.*correct/total, correct, total))
41+
return correct, total
42+
43+
if __name__ == '__main__':
44+
parser = argparse.ArgumentParser(description='VGGNet Test Tool')
45+
parser.add_argument('mtype', type=str, choices=['pytorch', 'torch-script'], help='Model type')
46+
args = parser.parse_args()
47+
48+
# Model
49+
print('==> Building model...')
50+
if args.mtype == 'pytorch':
51+
model = VGGNet('D-DSM', num_classes=10, input_size=32) # depthwise separable
52+
# Load model
53+
print('==> Loading PyTorch model...')
54+
model.load_state_dict(torch.load('VGG16model.pth'))
55+
model.to(device)
56+
else:
57+
print('==> Loading Torch Script model...')
58+
# Load ScriptModule from io.BytesIO object
59+
with open('VGG16-traced-eval.pt', 'rb') as f:
60+
buffer = io.BytesIO(f.read())
61+
model = torch.jit.load(buffer)
62+
print('[WARNING] ScriptModules cannot be moved to a GPU device yet. Running strictly on CPU for now.')
63+
device = torch.device('cpu') # 'to' is not supported on TracedModules (yet)
64+
65+
transform_test = transforms.Compose([
66+
transforms.ToTensor(),
67+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
68+
])
69+
70+
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
71+
test_loader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
72+
73+
if device.type == 'cuda':
74+
cudnn.benchmark = True
75+
model = torch.nn.DataParallel(model)
76+
77+
t0 = time.time()
78+
correct, total = test(model, test_loader)
79+
t1 = time.time()
80+
print('Accuracy of the network on test dataset: %f (%d/%d)' % (100.*correct/total, correct, total))
81+
print('Elapsed time: {} seconds'.format(t1-t0))

pytorch/to_torch_script.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch
2+
import argparse
3+
4+
from torch.jit import trace
5+
6+
from vgg import VGGNet
7+
8+
# Check device
9+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10+
print('[Device] {}'.format(device))
11+
12+
if __name__ == '__main__':
13+
parser = argparse.ArgumentParser(description='PyTorch Model to Torch Script')
14+
parser.add_argument('mode', type=str, choices=['train', 'eval'], help='Model mode')
15+
args = parser.parse_args()
16+
17+
example_input = torch.rand(1, 3, 32, 32)
18+
# TracedModule objects do not inherit the .to() or .eval() methods
19+
20+
if args.mode == 'train':
21+
print('==> Building model...')
22+
model = VGGNet('D-DSM', num_classes=10, input_size=32)
23+
#model.to(device)
24+
model.train()
25+
26+
# convert to Torch Script
27+
print('==> Tracing model...')
28+
traced_model = trace(model, example_input)
29+
30+
# save model for training
31+
traced_model.save('VGG16-traced-train.pt')
32+
else:
33+
# load "normal" pytorch trained model
34+
print('==> Building model...')
35+
model = VGGNet('D-DSM', num_classes=10, input_size=32)
36+
print('==> Loading pre-trained model...')
37+
model.load_state_dict(torch.load('VGG16model.pth', map_location=torch.device('cpu')))
38+
#model = model.to(device)
39+
model.eval()
40+
41+
# convert to Torch Script
42+
print('==> Tracing model...')
43+
traced_model = trace(model, example_input)
44+
45+
# save model for eval
46+
traced_model.save('VGG16-traced-eval.pt')

pytorch/train.py

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import torch
2+
import torch.nn as nn
3+
import torchvision.datasets as datasets
4+
import torchvision.transforms as transforms
5+
import torch.backends.cudnn as cudnn
6+
import numpy as np
7+
import argparse
8+
import time
9+
import io
10+
11+
from torch.utils.data.sampler import SubsetRandomSampler
12+
from torch.utils.data import Dataset, DataLoader
13+
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
14+
15+
from torch import jit
16+
17+
from vgg import VGGNet
18+
19+
# Check device
20+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21+
#device = torch.device('cpu')
22+
23+
def train(model, train_loader, criterion, optimizer, epoch):
24+
model.train()
25+
print_freq = 10 # print every 10 batches
26+
train_loss = 0
27+
correct = 0
28+
total = 0
29+
print('\nEpoch: %d' % epoch)
30+
31+
for batch_idx, (inputs, targets) in enumerate(train_loader):
32+
inputs, targets = inputs.to(device), targets.to(device)
33+
optimizer.zero_grad()
34+
35+
# compute output
36+
outputs = model(inputs)
37+
loss = criterion(outputs, targets)
38+
39+
# compute gradient and do SGD step
40+
loss.backward()
41+
optimizer.step()
42+
43+
# record loss and accuracy
44+
train_loss += loss.item()
45+
_, predicted = outputs.max(1)
46+
total += targets.size(0)
47+
correct += predicted.eq(targets).sum().item()
48+
49+
if batch_idx % print_freq == 0:
50+
print('Batch: %d, Loss: %.3f | Acc: %.3f%% (%d/%d)' % (batch_idx+1, train_loss/(batch_idx+1), 100.*correct/total, correct, total))
51+
52+
def validate(model, val_loader, criterion):
53+
model.eval()
54+
print_freq = 10 # print every 10 batches
55+
val_loss = 0.0
56+
57+
with torch.no_grad(): # no need to track history
58+
for batch_idx, (inputs, targets) in enumerate(val_loader):
59+
inputs, targets = inputs.to(device), targets.to(device)
60+
61+
# compute output
62+
outputs = model(inputs)
63+
loss = criterion(outputs, targets)
64+
65+
# record loss
66+
val_loss += loss.item()
67+
68+
if batch_idx % print_freq == 0:
69+
print('Validation on Batch: %d, Loss: %f' % (batch_idx+1, val_loss/(batch_idx+1)))
70+
return val_loss
71+
72+
if __name__ == '__main__':
73+
parser = argparse.ArgumentParser(description='VGGNet Training Tool')
74+
parser.add_argument('mtype', type=str, choices=['pytorch', 'torch-script'], help='Model type')
75+
args = parser.parse_args()
76+
# Load CIFAR10 dataset
77+
print('==> Preparing data...')
78+
transform_train = transforms.Compose([
79+
transforms.RandomCrop(32, padding=4),
80+
transforms.RandomHorizontalFlip(),
81+
transforms.ToTensor(),
82+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
83+
])
84+
85+
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
86+
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)
87+
88+
# Model
89+
print('==> Building model...')
90+
#model = VGGNet('D', num_classes=10, input_size=32) # VGG16 is configuration D (refer to paper)
91+
if args.mtype == 'torch-script':
92+
# print('==> From Torch Script...')
93+
# # Load ScriptModule from io.BytesIO object
94+
# with open('VGG16-traced-train.pt', 'rb') as f:
95+
# buffer = io.BytesIO(f.read())
96+
# model = torch.jit.load(buffer)
97+
raise RuntimeError('Training is not supported on ScriptModules yet.') #https://github.com/pytorch/pytorch/issues/6008
98+
99+
else:
100+
model = VGGNet('D-DSM', num_classes=10, input_size=32) # depthwise separable
101+
model = model.to(device)
102+
103+
if device.type == 'cuda':
104+
cudnn.benchmark = True
105+
model = torch.nn.DataParallel(model)
106+
107+
# Training
108+
num_epochs = 200 # as opposed to the paper (74) because of CIFAR10 dataset
109+
lr = 0.1
110+
# define loss function (criterion) and optimizer
111+
criterion = nn.CrossEntropyLoss()
112+
optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=5e-4)
113+
114+
print('==> Training...')
115+
train_time = 0
116+
#scheduler = ReduceLROnPlateau(optimizer, 'min')
117+
scheduler = StepLR(optimizer, step_size=100, gamma=0.1) # adjust lr by factor of 10 every 100 epochs
118+
for epoch in range(num_epochs):
119+
t0 = time.time()
120+
# train one epoch
121+
train(model, train_loader, criterion, optimizer, epoch)
122+
t1 = time.time() - t0
123+
print('{} seconds'.format(t1))
124+
train_time += t1
125+
126+
# validate
127+
#val_loss = validate(model, val_loader, criterion)
128+
# adjust learning rate with scheduler
129+
#scheduler.step(val_loss)
130+
scheduler.step()
131+
132+
print('==> Finished Training: {} seconds'.format(train_time))
133+
# Save trained model
134+
torch.save(model.state_dict(), 'VGG16model.pth')

0 commit comments

Comments
 (0)