
Commit e092e31

first commit

1 parent 7de46d6 commit e092e31

15 files changed: +750 −64 lines

.gitattributes (−17): This file was deleted.

.gitignore (−47): This file was deleted.

mnist_data.py (+141, new file)
# Some code was borrowed from https://github.com/petewarden/tensorflow_makefile/blob/master/tensorflow/models/image/mnist/convolutional.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os

import numpy
from scipy import ndimage
from six.moves import urllib

import tensorflow as tf

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
DATA_DIRECTORY = "data"

# Params for MNIST
IMAGE_SIZE = 28
NUM_CHANNELS = 1
PIXEL_DEPTH = 255
NUM_LABELS = 10
VALIDATION_SIZE = 5000  # Size of the validation set.


# Download MNIST data
def maybe_download(filename):
    """Download the data from Yann's website, unless it's already here."""
    if not tf.gfile.Exists(DATA_DIRECTORY):
        tf.gfile.MakeDirs(DATA_DIRECTORY)
    filepath = os.path.join(DATA_DIRECTORY, filename)
    if not tf.gfile.Exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        with tf.gfile.GFile(filepath) as f:
            size = f.size()
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath


# Extract the images
def extract_data(filename, num_images, norm_shift=False, norm_scale=True):
    """Extract the images into a 2D array [image index, pixels].

    Depending on the flags, values are shifted by -PIXEL_DEPTH/2 and/or
    scaled down by PIXEL_DEPTH (with both on, [0, 255] maps to [-0.5, 0.5]).
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(16)  # skip the 16-byte IDX header
        buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
        data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
        if norm_shift:
            data = data - (PIXEL_DEPTH / 2.0)
        if norm_scale:
            data = data / PIXEL_DEPTH
        data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
        data = numpy.reshape(data, [num_images, -1])  # flatten to [num_images, 784]
    return data


# Extract the labels
def extract_labels(filename, num_images):
    """Extract the labels into a one-hot matrix [image index, label index]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        bytestream.read(8)  # skip the 8-byte IDX header
        buf = bytestream.read(1 * num_images)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
        num_labels_data = len(labels)
        one_hot_encoding = numpy.zeros((num_labels_data, NUM_LABELS))
        one_hot_encoding[numpy.arange(num_labels_data), labels] = 1
        one_hot_encoding = numpy.reshape(one_hot_encoding, [-1, NUM_LABELS])
    return one_hot_encoding


# Augment training data
def expend_training_data(images, labels):
    expanded_images = []
    expanded_labels = []

    j = 0  # counter
    for x, y in zip(images, labels):
        j += 1
        if j % 100 == 0:
            print('expanding data : %03d / %03d' % (j, numpy.size(images, 0)))

        # register original data
        expanded_images.append(x)
        expanded_labels.append(y)

        # get a value for the background:
        # zero is the expected value, but median() is used to estimate it
        bg_value = numpy.median(x)
        image = numpy.reshape(x, (-1, 28))

        for i in range(4):
            # rotate the image by a random angle
            angle = numpy.random.randint(-15, 15)
            new_img = ndimage.rotate(image, angle, reshape=False, cval=bg_value)

            # shift the image by a random offset in each direction
            shift = numpy.random.randint(-2, 2, 2)
            new_img_ = ndimage.shift(new_img, shift, cval=bg_value)

            # register new training data
            expanded_images.append(numpy.reshape(new_img_, 784))
            expanded_labels.append(y)

    # images and labels are concatenated for random shuffling at each epoch;
    # note that image/label pairs must not be broken apart
    expanded_train_total_data = numpy.concatenate((expanded_images, expanded_labels), axis=1)
    numpy.random.shuffle(expanded_train_total_data)

    return expanded_train_total_data


# Prepare MNIST data
def prepare_MNIST_data(use_norm_shift=False, use_norm_scale=True, use_data_augmentation=False):
    # Get the data.
    train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
    train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
    test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
    test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')

    # Extract it into numpy arrays.
    train_data = extract_data(train_data_filename, 60000, use_norm_shift, use_norm_scale)
    train_labels = extract_labels(train_labels_filename, 60000)
    test_data = extract_data(test_data_filename, 10000, use_norm_shift, use_norm_scale)
    test_labels = extract_labels(test_labels_filename, 10000)

    # Generate a validation set.
    validation_data = train_data[:VALIDATION_SIZE, :]
    validation_labels = train_labels[:VALIDATION_SIZE, :]
    train_data = train_data[VALIDATION_SIZE:, :]
    train_labels = train_labels[VALIDATION_SIZE:, :]

    # Concatenate train_data & train_labels for random shuffling
    if use_data_augmentation:
        train_total_data = expend_training_data(train_data, train_labels)
    else:
        train_total_data = numpy.concatenate((train_data, train_labels), axis=1)

    train_size = train_total_data.shape[0]

    return train_total_data, train_size, validation_data, validation_labels, test_data, test_labels
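
A minimal usage sketch (an assumption, not part of this commit): it treats the file as importable under the name mnist_data and relies only on the column layout produced by the concatenation above, 784 pixel columns followed by NUM_LABELS one-hot label columns.

import numpy
import mnist_data

train_total_data, train_size, validation_data, validation_labels, \
    test_data, test_labels = mnist_data.prepare_MNIST_data()

numpy.random.shuffle(train_total_data)  # re-shuffle at each epoch
train_x = train_total_data[:, :-mnist_data.NUM_LABELS]  # pixels, shape (55000, 784)
train_y = train_total_data[:, -mnist_data.NUM_LABELS:]  # one-hot labels, shape (55000, 10)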

plot_utils.py (+117, new file)
import numpy as np
import matplotlib.pyplot as plt
# NOTE: scipy.misc.imsave/imresize were removed in newer SciPy releases;
# as written, this module needs an older SciPy (with Pillow installed).
from scipy.misc import imsave
from scipy.misc import imresize


class Plot_Reproduce_Performance():
    def __init__(self, DIR, n_img_x=8, n_img_y=8, img_w=28, img_h=28, resize_factor=1.0):
        self.DIR = DIR

        assert n_img_x > 0 and n_img_y > 0
        self.n_img_x = n_img_x
        self.n_img_y = n_img_y
        self.n_tot_imgs = n_img_x * n_img_y

        assert img_w > 0 and img_h > 0
        self.img_w = img_w
        self.img_h = img_h

        assert resize_factor > 0
        self.resize_factor = resize_factor

    def save_images(self, images, name='result.jpg'):
        images = images.reshape(self.n_img_x * self.n_img_y, self.img_h, self.img_w)
        imsave(self.DIR + "/" + name, self._merge(images, [self.n_img_y, self.n_img_x]))

    def _merge(self, images, size):
        h, w = images.shape[1], images.shape[2]

        h_ = int(h * self.resize_factor)
        w_ = int(w * self.resize_factor)

        img = np.zeros((h_ * size[0], w_ * size[1]))

        for idx, image in enumerate(images):
            i = idx % size[1]   # column in the grid
            j = idx // size[1]  # row in the grid

            image_ = imresize(image, size=(h_, w_), interp='bicubic')  # size is (height, width)

            img[j * h_:j * h_ + h_, i * w_:i * w_ + w_] = image_

        return img


class Plot_Manifold_Learning_Result():
    def __init__(self, DIR, n_img_x=20, n_img_y=20, img_w=28, img_h=28, resize_factor=1.0, z_range=4):
        self.DIR = DIR

        assert n_img_x > 0 and n_img_y > 0
        self.n_img_x = n_img_x
        self.n_img_y = n_img_y
        self.n_tot_imgs = n_img_x * n_img_y

        assert img_w > 0 and img_h > 0
        self.img_w = img_w
        self.img_h = img_h

        assert resize_factor > 0
        self.resize_factor = resize_factor

        assert z_range > 0
        self.z_range = z_range

        self._set_latent_vectors()

    def _set_latent_vectors(self):
        # a uniform grid would also work:
        # z1 = np.linspace(-self.z_range, self.z_range, self.n_img_y)
        # z2 = np.linspace(-self.z_range, self.z_range, self.n_img_x)
        # z = np.array(np.meshgrid(z1, z2)).reshape([-1, 2])

        # borrowed from https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py
        # z = np.rollaxis(np.mgrid[self.z_range:-self.z_range:self.n_img_y * 1j, self.z_range:-self.z_range:self.n_img_x * 1j], 0, 3)

        # square the grid while keeping its sign, so samples concentrate near the origin
        z1 = np.rollaxis(np.mgrid[1:-1:self.n_img_y * 1j, 1:-1:self.n_img_x * 1j], 0, 3)
        z = z1 ** 2
        z[z1 < 0] *= -1

        z = z * self.z_range

        self.z = z.reshape([-1, 2])

    def save_images(self, images, name='result.jpg'):
        images = images.reshape(self.n_img_x * self.n_img_y, self.img_h, self.img_w)
        imsave(self.DIR + "/" + name, self._merge(images, [self.n_img_y, self.n_img_x]))

    def _merge(self, images, size):
        h, w = images.shape[1], images.shape[2]

        h_ = int(h * self.resize_factor)
        w_ = int(w * self.resize_factor)

        img = np.zeros((h_ * size[0], w_ * size[1]))

        for idx, image in enumerate(images):
            i = idx % size[1]   # column in the grid
            j = idx // size[1]  # row in the grid

            image_ = imresize(image, size=(h_, w_), interp='bicubic')  # size is (height, width)

            img[j * h_:j * h_ + h_, i * w_:i * w_ + w_] = image_

        return img

    # borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb
    def save_scattered_image(self, z, id, name='scattered_image.jpg'):
        plt.figure(figsize=(8, 6))
        plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1))
        plt.colorbar()
        plt.grid(True)
        plt.savefig(self.DIR + "/" + name)
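
A minimal usage sketch, likewise an assumption rather than part of this commit: it tiles an 8x8 grid of flattened 28x28 images into a single file, with random pixels standing in for real network reconstructions; the directory name 'results' and file name 'sample.jpg' are illustrative.

import numpy as np
from plot_utils import Plot_Reproduce_Performance

PRR = Plot_Reproduce_Performance('results', n_img_x=8, n_img_y=8)
batch = np.random.rand(8 * 8, 28 * 28)  # stand-in for real reconstructions
PRR.save_images(batch, name='sample.jpg')  # the 'results' directory must already exist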

results/PMLR.jpg (64.4 KB)
results/PMLR_map.jpg (162 KB)
results/denoising.jpg (16.1 KB)
results/dim_z_10.jpg (18.1 KB)
results/dim_z_2.jpg (16 KB)
results/dim_z_20.jpg (18.6 KB)
results/dim_z_5.jpg (17.2 KB)
results/input.jpg (21.5 KB)
results/input_noise.jpg (41.9 KB)

0 commit comments