import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from collections import OrderedDict
import torchvision.datasets as dset
import torchvision.transforms as T

import random
import numpy as np
# scipy.ndimage.filters is deprecated; gaussian_filter1d lives in scipy.ndimage.
from scipy.ndimage import gaussian_filter1d

SQUEEZENET_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float)
SQUEEZENET_STD = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float)
| 19 | + |
| 20 | +### Helper Functions |
| 21 | +''' |
| 22 | +Our pretrained model was trained on images that had been preprocessed by subtracting |
| 23 | +the per-color mean and dividing by the per-color standard deviation. We define a few helper |
| 24 | +functions for performing and undoing this preprocessing. |
| 25 | +''' |
| 26 | +def preprocess(img, size=224): |
| 27 | + transform = T.Compose([ |
| 28 | + T.Resize(size), |
| 29 | + T.ToTensor(), |
| 30 | + T.Normalize(mean=SQUEEZENET_MEAN.tolist(), |
| 31 | + std=SQUEEZENET_STD.tolist()), |
| 32 | + T.Lambda(lambda x: x[None]), |
| 33 | + ]) |
| 34 | + return transform(img) |

def deprocess(img, should_rescale=True):
    # Pass should_rescale=True when deprocessing style transfer outputs.
    transform = T.Compose([
        T.Lambda(lambda x: x[0]),
        T.Normalize(mean=[0, 0, 0], std=(1.0 / SQUEEZENET_STD).tolist()),
        T.Normalize(mean=(-SQUEEZENET_MEAN).tolist(), std=[1, 1, 1]),
        T.Lambda(rescale) if should_rescale else T.Lambda(lambda x: x),
        T.ToPILImage(),
    ])
    return transform(img)

def rescale(x):
    # Linearly rescale the values of x into the range [0, 1].
    low, high = x.min(), x.max()
    x_rescaled = (x - low) / (high - low)
    return x_rescaled
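

# The round trip below is a minimal sketch of how preprocess and deprocess
# compose; the input image is synthetic rather than a real photo, and the
# helper name is just for illustration.
def _example_preprocess_roundtrip():
    from PIL import Image
    img = Image.fromarray(np.uint8(np.random.rand(224, 224, 3) * 255))
    x = preprocess(img)   # normalized float tensor of shape (1, 3, 224, 224)
    out = deprocess(x)    # back to a PIL image, rescaled into [0, 1]
    return out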

def blur_image(X, sigma=1):
    # Blur each image in a batch (N, C, H, W) in place by applying a 1D
    # Gaussian filter separably along the H and W axes.
    X_np = X.cpu().clone().numpy()
    X_np = gaussian_filter1d(X_np, sigma, axis=2)
    X_np = gaussian_filter1d(X_np, sigma, axis=3)
    X.copy_(torch.Tensor(X_np).type_as(X))
    return X
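

# A minimal sketch of using blur_image as a smoothing step during
# image-optimization visualizations; the tensor here is synthetic and the
# function name is just for illustration.
def _example_blur_usage():
    X = torch.randn(1, 3, 224, 224)
    blur_image(X, sigma=0.5)  # X is modified in place and also returned
    return X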


# Older versions of scipy.misc.imresize yield different results
# from newer versions, so we check to make sure scipy is up to date.
def check_scipy():
    import scipy
    version = scipy.__version__.split('.')
    major_vnum = int(version[0])
    minor_vnum = int(version[1])

    assert major_vnum >= 1 or minor_vnum >= 16, "You must install SciPy >= 0.16.0 to complete this notebook."

def jitter(X, ox, oy):
    """
    Helper function to randomly jitter an image.

    Inputs
    - X: PyTorch Tensor of shape (N, C, H, W)
    - ox, oy: Integers giving number of pixels to jitter along W and H axes

    Returns: A new PyTorch Tensor of shape (N, C, H, W), circularly shifted
    by ox pixels along W and oy pixels along H (like torch.roll).
    """
    if ox != 0:
        left = X[:, :, :, :-ox]
        right = X[:, :, :, -ox:]
        X = torch.cat([right, left], dim=3)
    if oy != 0:
        top = X[:, :, :-oy]
        bottom = X[:, :, -oy:]
        X = torch.cat([bottom, top], dim=2)
    return X
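

# A minimal sketch of the jitter/unjitter pattern used when optimizing an
# image: jitter is a circular shift, so applying it again with negated
# offsets restores the original exactly. The demo tensor is synthetic.
def _example_jitter_usage():
    X = torch.randn(2, 3, 224, 224)
    ox, oy = random.randint(1, 16), random.randint(1, 16)
    X_jit = jitter(X, ox, oy)         # shift right by ox, down by oy
    # ... e.g. take a gradient step on X_jit here ...
    X_back = jitter(X_jit, -ox, -oy)  # undo the shift
    assert torch.allclose(X, X_back)
    return X_back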


def load_CIFAR(path='./datasets/'):
    NUM_TRAIN = 49000
    # The torchvision.transforms package provides tools for preprocessing data
    # and for performing data augmentation; here we set up a transform to
    # preprocess the data by subtracting the mean RGB value and dividing by the
    # standard deviation of each RGB value; we've hardcoded the mean and std.
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # We set up a Dataset object for each split (train / val / test); Datasets load
    # training examples one at a time, so we wrap each Dataset in a DataLoader which
    # iterates through the Dataset and forms minibatches. We divide the CIFAR-10
    # training set into train and val sets by passing a Sampler object to the
    # DataLoader telling it how to sample from the underlying Dataset.
    cifar10_train = dset.CIFAR10(path, train=True, download=True,
                                 transform=transform)
    loader_train = DataLoader(cifar10_train, batch_size=64,
                              sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

    cifar10_val = dset.CIFAR10(path, train=True, download=True,
                               transform=transform)
    loader_val = DataLoader(cifar10_val, batch_size=64,
                            sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

    cifar10_test = dset.CIFAR10(path, train=False, download=True,
                                transform=transform)
    loader_test = DataLoader(cifar10_test, batch_size=64)
    return loader_train, loader_val, loader_test

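
# A minimal sketch of consuming the CIFAR-10 loaders; calling load_CIFAR
# downloads the dataset into ./datasets/ on first use. The function name
# below is just for illustration.
def _example_cifar_usage():
    loader_train, loader_val, loader_test = load_CIFAR()
    for X, y in loader_train:
        print(X.shape, y.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])
        break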

def load_imagenet_val(num=None, path='./datasets/imagenet_val_25.npz'):
    """Load a handful of validation images from ImageNet.

    Inputs:
    - num: Number of images to load (max of 25)
    - path: Path to the .npz file containing the images

    Returns:
    - X: numpy array with shape [num, 224, 224, 3]
    - y: numpy array of integer image labels, shape [num]
    - class_names: dict mapping integer label to class name
    """
    imagenet_fn = path
    if not os.path.isfile(imagenet_fn):
        print('file %s not found' % imagenet_fn)
        print('Run the above cell to download the data')
        assert False, 'Need to download imagenet_val_25.npz'
    f = np.load(imagenet_fn, allow_pickle=True)
    X = f['X']
    y = f['y']
    class_names = f['label_map'].item()
    if num is not None:
        X = X[:num]
        y = y[:num]
    return X, y, class_names

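
# A minimal sketch of loading a few ImageNet validation images; it assumes
# imagenet_val_25.npz has already been downloaded to the default path.
def _example_imagenet_usage():
    X, y, class_names = load_imagenet_val(num=5)
    print(X.shape)  # (5, 224, 224, 3)
    print([class_names[label] for label in y])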

def load_COCO(path='./datasets/coco.pt'):
    '''
    Download and load serialized COCO data from coco.pt.
    It contains a dictionary of
    "train_images" - resized training images (112x112)
    "val_images" - resized validation images (112x112)
    "train_captions" - tokenized and numericalized training captions
    "val_captions" - tokenized and numericalized validation captions
    "vocab" - caption vocabulary, including "idx_to_token" and "token_to_idx"

    Returns: a data dictionary
    '''
    data_dict = torch.load(path)
    # print out all the keys and values from the data dictionary
    for k, v in data_dict.items():
        if isinstance(v, torch.Tensor):
            print(k, type(v), v.shape, v.dtype)
        else:
            print(k, type(v), v.keys())

    num_train = data_dict['train_images'].size(0)
    num_val = data_dict['val_images'].size(0)
    assert num_train == data_dict['train_captions'].size(0) and \
           num_val == data_dict['val_captions'].size(0), \
        'shapes of data mismatch!'

    print('\nTrain images shape: ', data_dict['train_images'].shape)
    print('Train caption tokens shape: ', data_dict['train_captions'].shape)
    print('Validation images shape: ', data_dict['val_images'].shape)
    print('Validation caption tokens shape: ', data_dict['val_captions'].shape)
    print('total number of caption tokens: ', len(data_dict['vocab']['idx_to_token']))
    print('mappings (list) from index to caption token: ', data_dict['vocab']['idx_to_token'])
    print('mappings (dict) from caption token to index: ', data_dict['vocab']['token_to_idx'])

    return data_dict

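
# A minimal sketch of pulling tensors out of the dictionary returned by
# load_COCO; the keys come from the docstring above, and the demo function
# name is just for illustration.
def _example_coco_usage():
    data_dict = load_COCO()
    train_images = data_dict['train_images']      # resized 112x112 training images
    train_captions = data_dict['train_captions']  # tokenized caption tensor
    vocab = data_dict['vocab']
    print(vocab['idx_to_token'][:10])  # first few tokens of the vocabulary
    return train_images, train_captions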

## Dump files for submission
def dump_results(submission, path):
    '''
    Dumps a dictionary as a .pkl file for the autograder.
    submission: a dictionary of results
    path: path for saving the dict object
    '''
    # del submission['rnn_model']
    # del submission['lstm_model']
    # del submission['attn_model']
    with open(path, "wb") as f:
        pickle.dump(submission, f)
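

# A minimal sketch of saving results for the autograder; the dictionary keys
# here are placeholders, not the real submission format.
def _example_dump_usage():
    submission = {'accuracy': 0.95, 'predictions': [0, 1, 2]}
    dump_results(submission, './submission.pkl')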