
Commit efe2967

Fix interface for Pendulum-v1 and new gym versions
1 parent 231ad37

File tree

5 files changed: +35, -33 lines

pyproject.toml
src/pytorch_mppi/mppi.py
tests/pendulum.py
tests/pendulum_approximate.py
tests/pendulum_approximate_continuous.py

Diff for: pyproject.toml (+2, -2)

@@ -1,6 +1,6 @@
 [project]
 name = "pytorch_mppi"
-version = "0.7.2"
+version = "0.7.3"
 description = "Model Predictive Path Integral (MPPI) implemented in pytorch"
 readme = "README.md"  # Optional
 
@@ -73,7 +73,7 @@ tune = [
 ]
 test = [
     "pytest",
-    'gym<=0.20',
+    'gym',
     'pygame',
     'pyglet==1.5.27',
     'window-recorder',

Diff for: src/pytorch_mppi/mppi.py (+12, -5)

@@ -203,6 +203,14 @@ def _dynamics(self, state, u, t):
     def _running_cost(self, state, u, t):
         return self.running_cost(state, u, t) if self.step_dependency else self.running_cost(state, u)
 
+    def shift_nominal_trajectory(self):
+        """
+        Shift the nominal trajectory forward one step
+        """
+        # shift command 1 time step
+        self.U = torch.roll(self.U, -1, dims=0)
+        self.U[-1] = self.u_init
+
     def command(self, state, shift_nominal_trajectory=True):
         """
         :param state: (nx) or (K x nx) current state, or samples of states (for propagating a distribution of states)
@@ -211,9 +219,7 @@ def command(self, state, shift_nominal_trajectory=True):
         :returns action: (nu) best action
         """
         if shift_nominal_trajectory:
-            # shift command 1 time step
-            self.U = torch.roll(self.U, -1, dims=0)
-            self.U[-1] = self.u_init
+            self.shift_nominal_trajectory()
 
         return self._command(state)
 
@@ -360,11 +366,12 @@ def run_mppi(mppi, env, retrain_dynamics, retrain_after_iter=50, iter=1000, rend
     dataset = torch.zeros((retrain_after_iter, mppi.nx + mppi.nu), dtype=mppi.U.dtype, device=mppi.d)
     total_reward = 0
     for i in range(iter):
-        state = env.state.copy()
+        state = env.unwrapped.state.copy()
         command_start = time.perf_counter()
        action = mppi.command(state)
         elapsed = time.perf_counter() - command_start
-        s, r, _, _ = env.step(action.cpu().numpy())
+        res = env.step(action.cpu().numpy())
+        s, r = res[0], res[1]
         total_reward += r
         logger.debug("action taken: %.4f cost received: %.4f time taken: %.5fs", action, -r, elapsed)
         if render:
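
Context for the `res = env.step(...)` change above: gym 0.26 and later return a 5-tuple (obs, reward, terminated, truncated, info) from step() instead of the old 4-tuple (obs, reward, done, info), so unpacking only the first two elements keeps run_mppi working on both. A minimal sketch of version-agnostic unpacking; the helper name step_env is illustrative and not part of the library:

def step_env(env, action):
    """Step an env and return (obs, reward, done) under both old and new gym step APIs."""
    res = env.step(action)
    if len(res) == 5:  # gym >= 0.26: obs, reward, terminated, truncated, info
        obs, reward, terminated, truncated, _ = res
        done = terminated or truncated
    else:  # older gym: obs, reward, done, info
        obs, reward, done, _ = res
    return obs, reward, done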

Diff for: tests/pendulum.py (+8, -11)

@@ -4,7 +4,7 @@
 import logging
 import math
 from pytorch_mppi import mppi
-from gym import wrappers, logger as gym_log
+from gym import logger as gym_log
 
 gym_log.set_level(gym_log.INFO)
 logger = logging.getLogger(__name__)
@@ -13,7 +13,7 @@
                     datefmt='%m-%d %H:%M:%S')
 
 if __name__ == "__main__":
-    ENV_NAME = "Pendulum-v0"
+    ENV_NAME = "Pendulum-v1"
     TIMESTEPS = 15  # T
     N_SAMPLES = 100  # K
     ACTION_LOW = -2.0
@@ -40,9 +40,9 @@ def dynamics(state, perturbed_action):
        u = perturbed_action
        u = torch.clamp(u, -2, 2)
 
-        newthdot = thdot + (-3 * g / (2 * l) * np.sin(th + np.pi) + 3. / (m * l ** 2) * u) * dt
+        newthdot = thdot + (3 * g / (2 * l) * np.sin(th) + 3.0 / (m * l ** 2) * u) * dt
+        newthdot = np.clip(newthdot, -8, 8)
         newth = th + newthdot * dt
-        newthdot = torch.clamp(newthdot, -8, 8)
 
         state = torch.cat((newth, newthdot), dim=1)
         return state
@@ -65,18 +65,15 @@ def train(new_data):
 
 
     downward_start = True
-    env = gym.make(ENV_NAME).env  # bypass the default TimeLimit wrapper
-    env.reset()
-    if downward_start:
-        env.state = [np.pi, 1]
+    env = gym.make(ENV_NAME, render_mode="human")
 
-    env = wrappers.Monitor(env, '/tmp/mppi/', force=True)
     env.reset()
     if downward_start:
-        env.env.state = [np.pi, 1]
+        env.state = env.unwrapped.state = [np.pi, 1]
 
     nx = 2
     mppi_gym = mppi.MPPI(dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS,
-                         lambda_=lambda_)
+                         lambda_=lambda_, u_min=torch.tensor(ACTION_LOW, device=d),
+                         u_max=torch.tensor(ACTION_HIGH, device=d), device=d)
     total_reward = mppi.run_mppi(mppi_gym, env, train)
     logger.info("Total reward %f", total_reward)

Diff for: tests/pendulum_approximate.py (+5, -6)

@@ -4,7 +4,7 @@
 import logging
 import math
 from pytorch_mppi import mppi
-from gym import wrappers, logger as gym_log
+from gym import logger as gym_log
 
 gym_log.set_level(gym_log.INFO)
 logger = logging.getLogger(__name__)
@@ -13,7 +13,7 @@
                     datefmt='%m-%d %H:%M:%S')
 
 if __name__ == "__main__":
-    ENV_NAME = "Pendulum-v0"
+    ENV_NAME = "Pendulum-v1"
     TIMESTEPS = 30  # T
     N_SAMPLES = 1000  # K
     ACTION_LOW = -2.0
@@ -168,10 +168,10 @@ def train(new_data):
 
 
     downward_start = True
-    env = gym.make(ENV_NAME).env  # bypass the default TimeLimit wrapper
+    env = gym.make(ENV_NAME, render_mode="human").env  # bypass the default TimeLimit wrapper
     env.reset()
     if downward_start:
-        env.state = [np.pi, 1]
+        env.state = env.unwrapped.state = [np.pi, 1]
 
     # bootstrap network with random actions
     if BOOT_STRAP_ITER:
@@ -188,10 +188,9 @@ def train(new_data):
        train(new_data)
        logger.info("bootstrapping finished")
 
-    env = wrappers.Monitor(env, '/tmp/mppi/', force=True)
     env.reset()
     if downward_start:
-        env.env.state = [np.pi, 1]
+        env.state = env.unwrapped.state = [np.pi, 1]
 
     mppi_gym = mppi.MPPI(dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS,
                          lambda_=lambda_, device=d, u_min=torch.tensor(ACTION_LOW, dtype=torch.double, device=d),
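
The render_mode="human" argument added to gym.make reflects another interface change: in gym 0.26+ the render mode is declared once at construction rather than passed per frame to env.render(mode=...), and the classic-control environments then handle window rendering themselves as they step. A small sketch of the newer pattern, assuming gym >= 0.26:

import gym

# render mode is fixed at construction; env.render(mode=...) is no longer used
env = gym.make("Pendulum-v1", render_mode="human")
obs, info = env.reset()
for _ in range(200):
    action = env.action_space.sample()  # random torque, just to drive the window
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()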

Diff for: tests/pendulum_approximate_continuous.py (+8, -9)

@@ -8,7 +8,7 @@
 import logging
 import math
 from pytorch_mppi import mppi
-from gym import wrappers, logger as gym_log
+from gym import logger as gym_log
 
 gym_log.set_level(gym_log.INFO)
 logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@
                     datefmt='%m-%d %H:%M:%S')
 
 if __name__ == "__main__":
-    ENV_NAME = "Pendulum-v0"
+    ENV_NAME = "Pendulum-v1"
     TIMESTEPS = 15  # T
     N_SAMPLES = 100  # K
     ACTION_LOW = -2.0
@@ -87,9 +87,9 @@ def true_dynamics(state, perturbed_action):
        u = perturbed_action
        u = torch.clamp(u, -2, 2)
 
-        newthdot = thdot + (-3 * g / (2 * l) * torch.sin(th + np.pi) + 3. / (m * l ** 2) * u) * dt
+        newthdot = thdot + (3 * g / (2 * l) * torch.sin(th) + 3.0 / (m * l ** 2) * u) * dt
+        newthdot = torch.clip(newthdot, -8, 8)
         newth = th + newthdot * dt
-        newthdot = torch.clamp(newthdot, -8, 8)
 
         state = torch.cat((newth, newthdot), dim=1)
         return state
@@ -176,10 +176,10 @@ def train(new_data):
 
 
     downward_start = True
-    env = gym.make(ENV_NAME).env  # bypass the default TimeLimit wrapper
+    env = gym.make(ENV_NAME, render_mode="human").env  # bypass the default TimeLimit wrapper
     env.reset()
     if downward_start:
-        env.state = [np.pi, 1]
+        env.state = env.unwrapped.state = [np.pi, 1]
 
     # bootstrap network with random actions
     if BOOT_STRAP_ITER:
@@ -196,12 +196,11 @@ def train(new_data):
        train(new_data)
        logger.info("bootstrapping finished")
 
-    env = wrappers.Monitor(env, '/tmp/mppi/', force=True)
     env.reset()
     if downward_start:
-        env.env.state = [np.pi, 1]
+        env.state = env.unwrapped.state = [np.pi, 1]
 
-    mppi_gym = mppi.MPPI(dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS,
+    mppi_gym = mppi.MPPI(true_dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS,
                          lambda_=lambda_, device=d, u_min=torch.tensor(ACTION_LOW, dtype=torch.double, device=d),
                          u_max=torch.tensor(ACTION_HIGH, dtype=torch.double, device=d))
     total_reward, data = mppi.run_mppi(mppi_gym, env, train)
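
On the dynamics/true_dynamics edits in these test files: since sin(th + pi) = -sin(th), the gravity term is algebraically unchanged; the substantive change is clipping the angular velocity to +/-8 before integrating the angle, matching the Pendulum-v1 update order. A hedged torch sketch of that update, batched over K rollout states, with constants matching gym's pendulum source:

import torch

# pendulum constants from the gym env (g is 10.0 there, not 9.81)
g, m, l, dt = 10.0, 1.0, 1.0, 0.05
max_speed, max_torque = 8.0, 2.0


def pendulum_step(state, u):
    """One Euler step of the Pendulum-v1 update rule over a (K x 2) batch of [th, thdot] states."""
    th, thdot = state[:, 0].view(-1, 1), state[:, 1].view(-1, 1)
    u = torch.clamp(u, -max_torque, max_torque)
    newthdot = thdot + (3 * g / (2 * l) * torch.sin(th) + 3.0 / (m * l ** 2) * u) * dt
    newthdot = torch.clamp(newthdot, -max_speed, max_speed)  # clip before integrating the angle
    newth = th + newthdot * dt
    return torch.cat((newth, newthdot), dim=1)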

0 commit comments
