|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torchvision.datasets as datasets |
| 4 | +import torchvision.transforms as transforms |
| 5 | +import torch.backends.cudnn as cudnn |
| 6 | +import numpy as np |
| 7 | +import argparse |
| 8 | +import time |
| 9 | +import io |
| 10 | + |
| 11 | +from torch.utils.data.sampler import SubsetRandomSampler |
| 12 | +from torch.utils.data import Dataset, DataLoader |
| 13 | +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR |
| 14 | + |
| 15 | +from torch import jit |
| 16 | + |
| 17 | +from vgg import VGGNet |
| 18 | + |
# Select the compute device once at import time: the GPU when CUDA is
# available, otherwise the CPU. train() and validate() read this global.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# To force CPU execution instead, use: device = torch.device('cpu')
| 22 | + |
def train(model, train_loader, criterion, optimizer, epoch, print_freq=10):
    """Run one training epoch over ``train_loader`` and print running statistics.

    Args:
        model: Network to optimize; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer whose gradients are zeroed before, and stepped
            after, every batch.
        epoch: Epoch index; only used in the progress header.
        print_freq: Print the running loss/accuracy every ``print_freq``
            batches. Defaults to 10 (the previously hard-coded value); a
            falsy value (``0`` or ``None``) disables per-batch printing.

    Returns:
        None. Progress is reported on stdout only.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    print('\nEpoch: %d' % epoch)

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Move the batch to the module-level ``device``.
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # compute gradient and do the optimizer step
        loss.backward()
        optimizer.step()

        # record loss and accuracy; the predicted label is the index of the
        # largest output along dimension 1
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Guard on print_freq so that 0/None turns logging off instead of
        # raising on the modulo.
        if print_freq and batch_idx % print_freq == 0:
            print('Batch: %d, Loss: %.3f | Acc: %.3f%% (%d/%d)' % (batch_idx+1, train_loss/(batch_idx+1), 100.*correct/total, correct, total))
| 51 | + |
def validate(model, val_loader, criterion, print_freq=10):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        print_freq: Print the running mean loss every ``print_freq`` batches.
            Defaults to 10 (the previously hard-coded value); a falsy value
            (``0`` or ``None``) disables per-batch printing.

    Returns:
        The sum of the per-batch loss values over the whole loader. Note that
        this is the accumulated total, not the mean that is printed.
    """
    model.eval()
    val_loss = 0.0

    with torch.no_grad():  # no need to track history
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            # Move the batch to the module-level ``device``.
            inputs, targets = inputs.to(device), targets.to(device)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # record loss
            val_loss += loss.item()

            # Guard on print_freq so that 0/None turns logging off instead
            # of raising on the modulo.
            if print_freq and batch_idx % print_freq == 0:
                print('Validation on Batch: %d, Loss: %f' % (batch_idx+1, val_loss/(batch_idx+1)))
    return val_loss
| 71 | + |
if __name__ == '__main__':
    # Command line: one positional argument selecting the model flavour.
    parser = argparse.ArgumentParser(description='VGGNet Training Tool')
    parser.add_argument('mtype', type=str, choices=['pytorch', 'torch-script'], help='Model type')
    args = parser.parse_args()

    # CIFAR10 training split, with random crop / horizontal flip augmentation
    # followed by per-channel normalization.
    print('==> Preparing data...')
    augmentation_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    transform_train = transforms.Compose(augmentation_steps)

    trainset = datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform_train,
    )
    train_loader = DataLoader(
        trainset,
        batch_size=128,
        shuffle=True,
        num_workers=4,
    )

    # Model
    print('==> Building model...')
    if args.mtype == 'torch-script':
        # Loading 'VGG16-traced-train.pt' through torch.jit.load was the
        # intended path here, but training a ScriptModule is rejected; see
        # https://github.com/pytorch/pytorch/issues/6008
        raise RuntimeError('Training is not supported on ScriptModules yet.') #https://github.com/pytorch/pytorch/issues/6008

    # 'D-DSM' is the depthwise separable variant; plain VGG16 would be
    # VGGNet('D', num_classes=10, input_size=32) (configuration D in the paper).
    model = VGGNet('D-DSM', num_classes=10, input_size=32)
    model = model.to(device)

    if device.type == 'cuda':
        cudnn.benchmark = True
        model = nn.DataParallel(model)

    # Training hyper-parameters
    num_epochs = 200  # the paper uses 74; more epochs are used here for CIFAR10
    lr = 0.1

    # Loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=0.9,
        weight_decay=5e-4,
    )

    print('==> Training...')
    train_time = 0
    # Multiply the learning rate by 0.1 every 100 epochs. The alternative,
    # ReduceLROnPlateau(optimizer, 'min') driven by validate(), is disabled:
    # no validation loader is built in this script.
    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
    for epoch in range(num_epochs):
        epoch_start = time.time()
        # train one epoch
        train(model, train_loader, criterion, optimizer, epoch)
        epoch_seconds = time.time() - epoch_start
        print('{} seconds'.format(epoch_seconds))
        train_time += epoch_seconds

        # advance the learning-rate schedule once per epoch
        scheduler.step()

    print('==> Finished Training: {} seconds'.format(train_time))
    # Save trained model weights.
    # NOTE(review): when the model was wrapped in DataParallel above, the saved
    # keys presumably carry a 'module.' prefix -- confirm the loading code
    # expects that.
    torch.save(model.state_dict(), 'VGG16model.pth')
0 commit comments