@@ -77,12 +77,12 @@ def _set_latent_vectors(self):
77
77
# z = z.reshape([-1, 2])
78
78
79
79
# borrowed from https://github.com/fastforwardlabs/vae-tf/blob/master/plot.py
80
- # z = np.rollaxis(np.mgrid[self.z_range:-self.z_range:self.n_img_y * 1j, self.z_range:-self.z_range:self.n_img_x * 1j], 0, 3)
81
- z1 = np .rollaxis (np .mgrid [1 :- 1 :self .n_img_y * 1j , 1 :- 1 :self .n_img_x * 1j ], 0 , 3 )
82
- z = z1 ** 2
83
- z [z1 < 0 ] *= - 1
84
-
85
- z = z * self .z_range
80
+ z = np .rollaxis (np .mgrid [self .z_range :- self .z_range :self .n_img_y * 1j , self .z_range :- self .z_range :self .n_img_x * 1j ], 0 , 3 )
81
+ # z1 = np.rollaxis(np.mgrid[1:-1:self.n_img_y * 1j, 1:-1:self.n_img_x * 1j], 0, 3)
82
+ # z = z1**2
83
+ # z[z1<0] *= -1
84
+ #
85
+ # z = z*self.z_range
86
86
87
87
self .z = z .reshape ([- 1 , 2 ])
88
88
@@ -110,8 +110,25 @@ def _merge(self, images, size):
110
110
111
111
# borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb
112
112
def save_scattered_image (self , z , id , name = 'scattered_image.jpg' ):
113
+ N = 10
113
114
plt .figure (figsize = (8 , 6 ))
114
- plt .scatter (z [:, 0 ], z [:, 1 ], c = np .argmax (id , 1 ))
115
- plt .colorbar ()
115
+ plt .scatter (z [:, 0 ], z [:, 1 ], c = np .argmax (id , 1 ), marker = 'o' , edgecolor = 'none' , cmap = discrete_cmap (N , 'jet' ))
116
+ plt .colorbar (ticks = range (N ))
117
+ axes = plt .gca ()
118
+ axes .set_xlim ([- self .z_range - 2 , self .z_range + 2 ])
119
+ axes .set_ylim ([- self .z_range - 2 , self .z_range + 2 ])
116
120
plt .grid (True )
117
- plt .savefig (self .DIR + "/" + name )
121
+ plt .savefig (self .DIR + "/" + name )
122
+
123
+ # borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
124
+ def discrete_cmap (N , base_cmap = None ):
125
+ """Create an N-bin discrete colormap from the specified input map"""
126
+
127
+ # Note that if base_cmap is a string or None, you can simply do
128
+ # return plt.cm.get_cmap(base_cmap, N)
129
+ # The following works for string, None, or a colormap instance:
130
+
131
+ base = plt .cm .get_cmap (base_cmap )
132
+ color_list = base (np .linspace (0 , 1 , N ))
133
+ cmap_name = base .name + str (N )
134
+ return base .from_list (cmap_name , color_list , N )
0 commit comments