Skip to content

Commit 0bf528c

Browse files
committed
update code
updat code to release version
1 parent 1931428 commit 0bf528c

File tree

3 files changed

+166
-221
lines changed

3 files changed

+166
-221
lines changed

plot_utils.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ def _set_latent_vectors(self):
7777
# z = z.reshape([-1, 2])
7878

7979
# borrowed from https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py
80-
#z = np.rollaxis(np.mgrid[self.z_range:-self.z_range:self.n_img_y * 1j, self.z_range:-self.z_range:self.n_img_x * 1j], 0, 3)
81-
z1 = np.rollaxis(np.mgrid[1:-1:self.n_img_y * 1j, 1:-1:self.n_img_x * 1j], 0, 3)
82-
z = z1**2
83-
z[z1<0] *= -1
84-
85-
z = z*self.z_range
80+
z = np.rollaxis(np.mgrid[self.z_range:-self.z_range:self.n_img_y * 1j, self.z_range:-self.z_range:self.n_img_x * 1j], 0, 3)
81+
# z1 = np.rollaxis(np.mgrid[1:-1:self.n_img_y * 1j, 1:-1:self.n_img_x * 1j], 0, 3)
82+
# z = z1**2
83+
# z[z1<0] *= -1
84+
#
85+
# z = z*self.z_range
8686

8787
self.z = z.reshape([-1, 2])
8888

@@ -110,8 +110,25 @@ def _merge(self, images, size):
110110

111111
# borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb
112112
def save_scattered_image(self, z, id, name='scattered_image.jpg'):
113+
N = 10
113114
plt.figure(figsize=(8, 6))
114-
plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1))
115-
plt.colorbar()
115+
plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
116+
plt.colorbar(ticks=range(N))
117+
axes = plt.gca()
118+
axes.set_xlim([-self.z_range-2, self.z_range+2])
119+
axes.set_ylim([-self.z_range-2, self.z_range+2])
116120
plt.grid(True)
117-
plt.savefig(self.DIR + "/" + name)
121+
plt.savefig(self.DIR + "/" + name)
122+
123+
# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
124+
def discrete_cmap(N, base_cmap=None):
125+
"""Create an N-bin discrete colormap from the specified input map"""
126+
127+
# Note that if base_cmap is a string or None, you can simply do
128+
# return plt.cm.get_cmap(base_cmap, N)
129+
# The following works for string, None, or a colormap instance:
130+
131+
base = plt.cm.get_cmap(base_cmap)
132+
color_list = base(np.linspace(0, 1, N))
133+
cmap_name = base.name + str(N)
134+
return base.from_list(cmap_name, color_list, N)

0 commit comments

Comments
 (0)