Skip to content

Commit c1b0857

Browse files
committed
3D plot
1 parent 26a1dea commit c1b0857

File tree

7 files changed

+60
-35
lines changed

7 files changed

+60
-35
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,17 @@ All methods start at the same location, specified by two variables. Both x and y
2121
For an overview of each gradient descent optimization algorithms, visit [this helpful resource](http://ruder.io/optimizing-gradient-descent/).
2222

2323
#### Numbers in figure legend indicate learning rate, specific to each Optimizer.
24+
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie11.gif)
2425
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie5.gif)
2526

2627
#### Note the optimizers' behavior when gradient is steep.
28+
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie9.gif)
29+
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie10.gif)
2730
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie7.gif)
2831
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie6.gif)
2932

3033
#### Note the optimizers' behavior when initial gradient is miniscule.
34+
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie12.gif)
3135
![](https://github.com/Jaewan-Yun/optimizer-visualization/blob/master/figures/movie8.gif)
3236

3337
<!-- ## Additional Figures

figures/movie10.gif

7.08 MB
Loading

figures/movie11.gif

8.08 MB
Loading

figures/movie12.gif

4.76 MB
Loading

figures/movie9.gif

6.91 MB
Loading

gif_maker.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import imageio
22

33
images = []
4-
filenames = ['figures/' + str(i) + '.png' for i in range(120)]
4+
filenames = ['figures/' + str(i) + '.png' for i in range(65)]
5+
6+
imageio.plugins.freeimage.download()
57

68
for filename in filenames:
79
images.append(imageio.imread(filename))
8-
imageio.mimsave('figures/movie.gif', images)
10+
imageio.mimsave('figures/movie.gif', images, format='GIF-FI', duration=0.001)

optimizer_visualization.py

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from matplotlib import cm
2+
from mpl_toolkits.mplot3d import Axes3D
3+
14
import matplotlib.pyplot as plt
25
import numpy as np
36
import tensorflow as tf
@@ -22,10 +25,10 @@ def cost_func(x=None, y=None):
2225
y = tf.placeholder(tf.float32, shape=[None, 1])
2326

2427
# two local minima near (0, 0)
25-
z = __f1(x, y)
28+
# z = __f1(x, y)
2629

2730
# 3rd local minimum at (-0.5, -0.8)
28-
z -= __f2(x, y, x_mean=-0.5, y_mean=-0.8, x_sig=0.35, y_sig=0.35)
31+
z = -1 * __f2(x, y, x_mean=-0.5, y_mean=-0.8, x_sig=0.35, y_sig=0.35)
2932

3033
# one steep gaussian trench at (0, 0)
3134
# z -= __f2(x, y, x_mean=0, y_mean=0, x_sig=0.2, y_sig=0.2)
@@ -53,7 +56,8 @@ def __f2(x, y, x_mean, y_mean, x_sig, y_sig):
5356

5457
# pyplot settings
5558
plt.ion()
56-
plt.figure(figsize=(3, 2), dpi=300)
59+
fig = plt.figure(figsize=(3, 2), dpi=300)
60+
ax = fig.add_subplot(111, projection='3d')
5761
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
5862
params = {'legend.fontsize': 3,
5963
'legend.handlelength': 3}
@@ -72,7 +76,9 @@ def __f2(x, y, x_mean, y_mean, x_sig, y_sig):
7276
z_val_mesh_flat = sess.run(z, feed_dict={x: x_val_mesh_flat, y: y_val_mesh_flat})
7377
z_val_mesh = z_val_mesh_flat.reshape(x_val_mesh.shape)
7478
levels = np.arange(-10, 1, 0.05)
75-
plt.contour(x_val_mesh, y_val_mesh, z_val_mesh, levels, alpha=.7, linewidths=0.4)
79+
# ax.contour(x_val_mesh, y_val_mesh, z_val_mesh, levels, alpha=.7, linewidths=0.4)
80+
# ax.plot_wireframe(x_val_mesh, y_val_mesh, z_val_mesh, alpha=.5, linewidths=0.4, antialiased=True)
81+
ax.plot_surface(x_val_mesh, y_val_mesh, z_val_mesh, alpha=.4, cmap=cm.coolwarm)
7682
plt.draw()
7783

7884
# starting location for variables
@@ -91,58 +97,71 @@ def __f2(x, y, x_mean, y_mean, x_sig, y_sig):
9197
cost.append(cost_func(x_var[i], y_var[i])[2])
9298

9399
# define method of gradient descent for each graph
94-
ops_param = [['Adadelta', 50],
95-
['Adagrad', 0.10],
96-
['Adam', 0.05],
97-
['Ftrl', 0.5],
98-
['GD', 0.05],
99-
['Momentum', 0.01],
100-
['RMSProp', 0.02]]
100+
# optimizer label name, learning rate, color
101+
ops_param = np.array([['Adadelta', 50.0, 'b'],
102+
['Adagrad', 0.10, 'g'],
103+
['Adam', 0.05, 'r'],
104+
['Ftrl', 0.5, 'c'],
105+
['GD', 0.05, 'm'],
106+
['Momentum', 0.01, 'y'],
107+
['RMSProp', 0.02, 'k']])
101108

102109
ops = []
103-
ops.append(tf.train.AdadeltaOptimizer(ops_param[0][1]).minimize(cost[0]))
104-
ops.append(tf.train.AdagradOptimizer(ops_param[1][1]).minimize(cost[1]))
105-
ops.append(tf.train.AdamOptimizer(ops_param[2][1]).minimize(cost[2]))
106-
ops.append(tf.train.FtrlOptimizer(ops_param[3][1]).minimize(cost[3]))
107-
ops.append(tf.train.GradientDescentOptimizer(ops_param[4][1]).minimize(cost[4]))
108-
ops.append(tf.train.MomentumOptimizer(ops_param[5][1], momentum=0.95).minimize(cost[5]))
109-
ops.append(tf.train.RMSPropOptimizer(ops_param[6][1]).minimize(cost[6]))
110+
ops.append(tf.train.AdadeltaOptimizer(float(ops_param[0, 1])).minimize(cost[0]))
111+
ops.append(tf.train.AdagradOptimizer(float(ops_param[1, 1])).minimize(cost[1]))
112+
ops.append(tf.train.AdamOptimizer(float(ops_param[2, 1])).minimize(cost[2]))
113+
ops.append(tf.train.FtrlOptimizer(float(ops_param[3, 1])).minimize(cost[3]))
114+
ops.append(tf.train.GradientDescentOptimizer(float(ops_param[4, 1])).minimize(cost[4]))
115+
ops.append(tf.train.MomentumOptimizer(float(ops_param[5, 1]), momentum=0.95).minimize(cost[5]))
116+
ops.append(tf.train.RMSPropOptimizer(float(ops_param[6, 1])).minimize(cost[6]))
117+
118+
# 3d plot camera zoom, angle
119+
xlm = ax.get_xlim3d()
120+
ylm = ax.get_ylim3d()
121+
zlm = ax.get_zlim3d()
122+
ax.set_xlim3d(xlm[0] * 0.5, xlm[1] * 0.5)
123+
ax.set_ylim3d(ylm[0] * 0.5, ylm[1] * 0.5)
124+
ax.set_zlim3d(zlm[0] * 0.5, zlm[1] * 0.5)
125+
azm = ax.azim
126+
ele = ax.elev + 40
127+
ax.view_init(elev=ele, azim=azm)
110128

111129
with tf.Session() as sess:
112130
sess.run(tf.global_variables_initializer())
113131

114132
# use last location to draw a line to the current location
115-
last_x, last_y = [], []
116-
plot_cache = []
117-
for i in range(7):
118-
last_x.append(x_i)
119-
last_y.append(y_i)
120-
plot_cache.append(None)
121-
122-
# available colors for each label
123-
colors = ('b', 'g', 'r', 'c', 'm', 'y', 'k')
133+
last_x, last_y, last_z = [], [], []
134+
plot_cache = [None for _ in range(len(ops))]
124135

125136
# loop each step of the optimization algorithm
126137
steps = 1000
127138
for iter in range(steps):
128139
for i, op in enumerate(ops):
129140
# run a step of optimization and collect new x and y variable values
130-
_, x_val, y_val = sess.run([op, x_var[i], y_var[i]])
141+
_, x_val, y_val, z_val = sess.run([op, x_var[i], y_var[i], cost[i]])
131142

132143
# move dot to the current value
133144
if plot_cache[i]:
134145
plot_cache[i].remove()
135-
plot_cache[i] = plt.scatter(x_val, y_val, color=colors[i], s=3, label=ops_param[i][0])
146+
plot_cache[i] = ax.scatter(x_val, y_val, z_val, s=3, depthshade=True, label=ops_param[i, 0], color=ops_param[i, 2])
136147

137148
# draw a line from the previous value
138-
if last_x[i] and last_y[i]:
139-
plt.plot([last_x[i], x_val], [last_y[i], y_val], color=colors[i], linewidth=0.5)
149+
if iter == 0:
150+
last_z.append(z_val)
151+
last_x.append(x_i)
152+
last_y.append(y_i)
153+
ax.plot([last_x[i], x_val], [last_y[i], y_val], [last_z[i], z_val], linewidth=0.5, color=ops_param[i, 2])
140154
last_x[i] = x_val
141155
last_y[i] = y_val
156+
last_z[i] = z_val
157+
158+
if iter == 0:
159+
legend = np.vstack((ops_param[:, 0], ops_param[:, 1])).transpose()
160+
plt.legend(plot_cache, legend)
142161

143-
plt.legend(plot_cache, ops_param)
144162
plt.savefig('figures/' + str(iter) + '.png')
145163
print('iteration: {}'.format(iter))
146-
plt.pause(0.001)
164+
165+
plt.pause(0.0001)
147166

148167
print("done")

0 commit comments

Comments
 (0)