# Some code was borrowed from https://github.com/petewarden/tensorflow_makefile/blob/master/tensorflow/models/image/mnist/convolutional.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os

import numpy
from scipy import ndimage

from six.moves import urllib

import tensorflow as tf

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
DATA_DIRECTORY = "data"

# Params for MNIST
IMAGE_SIZE = 28
NUM_CHANNELS = 1
PIXEL_DEPTH = 255
NUM_LABELS = 10
VALIDATION_SIZE = 5000  # Size of the validation set.
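
# A worked example of the normalization options used by extract_data below:
# with norm_shift and norm_scale both on, a raw pixel p in [0, 255] becomes
# (p - PIXEL_DEPTH / 2) / PIXEL_DEPTH, e.g. p = 255 -> (255 - 127.5) / 255 = 0.5,
# so values land in [-0.5, 0.5]; with only norm_scale (the default) the value
# p / 255 lands in [0, 1].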

# Download MNIST data
def maybe_download(filename):
    """Download the data from Yann's website, unless it's already here."""
    if not tf.gfile.Exists(DATA_DIRECTORY):
        tf.gfile.MakeDirs(DATA_DIRECTORY)
    filepath = os.path.join(DATA_DIRECTORY, filename)
    if not tf.gfile.Exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        with tf.gfile.GFile(filepath) as f:
            size = f.size()
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath
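
# A minimal usage sketch (filename taken from the calls in prepare_MNIST_data
# below; the file is fetched into the "data" directory only if absent):
#
#   path = maybe_download('train-images-idx3-ubyte.gz')
#   # path == 'data/train-images-idx3-ubyte.gz'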

# Extract the images
def extract_data(filename, num_images, norm_shift=False, norm_scale=True):
    """Extract the images into a 2D matrix [image index, flattened pixels].

    Raw values lie in [0, 255]; norm_shift subtracts PIXEL_DEPTH / 2 and
    norm_scale divides by PIXEL_DEPTH, so with both enabled the values end
    up in [-0.5, 0.5], and with only norm_scale (the default) in [0, 1].
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(16)  # skip the 16-byte IDX header (magic, count, rows, cols)
        buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
        data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
        if norm_shift:
            data = data - (PIXEL_DEPTH / 2.0)
        if norm_scale:
            data = data / PIXEL_DEPTH
        # Flatten each 28x28 image into a 784-dimensional row vector.
        data = numpy.reshape(data, [num_images, -1])
    return data
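
# A minimal sanity-check sketch for extract_data (assumes the archive was
# already fetched into the default "data" directory by maybe_download):
#
#   train_images = extract_data('data/train-images-idx3-ubyte.gz', 60000)
#   # train_images.shape == (60000, 784), dtype float32, values in [0, 1]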

# Extract the labels
def extract_labels(filename, num_images):
    """Extract the labels into a one-hot matrix [image index, label index]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(8)  # skip the 8-byte IDX header (magic, count)
        buf = bytestream.read(1 * num_images)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
    # Convert label IDs to one-hot rows, e.g. 3 -> [0, 0, 0, 1, 0, ..., 0].
    num_labels_data = len(labels)
    one_hot_encoding = numpy.zeros((num_labels_data, NUM_LABELS))
    one_hot_encoding[numpy.arange(num_labels_data), labels] = 1
    return one_hot_encoding
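
# A minimal sanity-check sketch for extract_labels (same assumption about
# the download location):
#
#   train_labels = extract_labels('data/train-labels-idx1-ubyte.gz', 60000)
#   # train_labels.shape == (60000, 10); exactly one 1 per row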

# Augment training data
def expand_training_data(images, labels):
    """Expand the training set: each image contributes itself plus four
    randomly rotated and shifted copies, i.e. a 5x larger set."""
    expanded_images = []
    expanded_labels = []

    j = 0  # counter
    for x, y in zip(images, labels):
        j += 1
        if j % 100 == 0:
            print('expanding data : %05d / %05d' % (j, numpy.size(images, 0)))

        # register original data
        expanded_images.append(x)
        expanded_labels.append(y)

        # Estimate the background value. Zero is the expected value, but the
        # median is used as a robust estimate; it fills the border pixels
        # exposed by rotation and shifting.
        bg_value = numpy.median(x)
        image = numpy.reshape(x, (-1, 28))

        for i in range(4):
            # rotate the image by a random angle in [-15, 15) degrees
            angle = numpy.random.randint(-15, 15)
            new_img = ndimage.rotate(image, angle, reshape=False, cval=bg_value)

            # shift the image by a random offset in [-2, 2) pixels per axis
            shift = numpy.random.randint(-2, 2, 2)
            new_img_ = ndimage.shift(new_img, shift, cval=bg_value)

            # register new training data
            expanded_images.append(numpy.reshape(new_img_, 784))
            expanded_labels.append(y)

    # Images and labels are concatenated column-wise so that the random
    # shuffle at each epoch cannot break an image/label pair.
    expanded_train_total_data = numpy.concatenate((expanded_images, expanded_labels), axis=1)
    numpy.random.shuffle(expanded_train_total_data)

    return expanded_train_total_data
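
# A minimal sketch of how the concatenated array splits back into images and
# labels (column layout follows from the concatenate call above):
#
#   total = expand_training_data(train_data, train_labels)
#   images, one_hot = total[:, :784], total[:, 784:]
#   # total.shape[0] == 5 * train_data.shape[0]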

# Prepare MNIST data
def prepare_MNIST_data(use_norm_shift=False, use_norm_scale=True, use_data_augmentation=False):
    # Get the data.
    train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
    train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
    test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
    test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')

    # Extract it into numpy arrays.
    train_data = extract_data(train_data_filename, 60000, use_norm_shift, use_norm_scale)
    train_labels = extract_labels(train_labels_filename, 60000)
    test_data = extract_data(test_data_filename, 10000, use_norm_shift, use_norm_scale)
    test_labels = extract_labels(test_labels_filename, 10000)

    # Generate a validation set.
    validation_data = train_data[:VALIDATION_SIZE, :]
    validation_labels = train_labels[:VALIDATION_SIZE, :]
    train_data = train_data[VALIDATION_SIZE:, :]
    train_labels = train_labels[VALIDATION_SIZE:, :]

    # Concatenate train_data & train_labels for random shuffle
    if use_data_augmentation:
        train_total_data = expand_training_data(train_data, train_labels)
    else:
        train_total_data = numpy.concatenate((train_data, train_labels), axis=1)

    train_size = train_total_data.shape[0]

    return train_total_data, train_size, validation_data, validation_labels, test_data, test_labels
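
# A minimal usage sketch (return shapes assume the default 5000-image
# validation split and no augmentation):
#
#   train_total_data, train_size, validation_data, validation_labels, \
#       test_data, test_labels = prepare_MNIST_data()
#   # train_total_data.shape == (55000, 794): 784 pixel columns + 10 label columns
#   # validation_data.shape == (5000, 784); test_data.shape == (10000, 784)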