
Commit ae36e28

Fix autotune for new ray tune
1 parent efe2967 commit ae36e28

2 files changed: +63 −24 lines


Diff for: src/pytorch_mppi/autotune.py

+46 −12
@@ -108,6 +108,10 @@ def ensure_valid_value(self, value):
     def apply_parameter_value(self, value):
         """Apply the parameter value to the underlying object"""

+    @abc.abstractmethod
+    def attach_to_state(self, state: dict):
+        """Reattach/reinitialize the parameter to a new internal state. This should be similar to a call to __init__"""
+
     def get_parameter_value_from_config(self, config):
         """Get the serialized value of the parameter from a config dictionary, where each name is a scalar"""
         return config[self.name()]
@@ -118,10 +122,19 @@ def get_config_from_parameter_value(self, value):


 class MPPIParameter(TunableParameter, abc.ABC):
-    def __init__(self, mppi: MPPI):
+    def __init__(self, mppi: MPPI, dim=None):
         self.mppi = mppi
-        self.d = mppi.d
-        self.dtype = mppi.dtype
+        self._dim = dim
+        if self.mppi is not None:
+            self.d = self.mppi.d
+            self.dtype = self.mppi.dtype
+            if dim is None:
+                self._dim = self.mppi.nu
+
+    def attach_to_state(self, state: dict):
+        self.mppi = state['mppi']
+        self.d = self.mppi.d
+        self.dtype = self.mppi.dtype


 class SigmaParameter(MPPIParameter):
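
Note: with this change a parameter can be constructed without a live controller (mppi=None, with an explicit dim) and bound to one later through attach_to_state. A minimal sketch of that contract, where make_controller is a hypothetical factory returning an MPPI instance and the control dimension of 2 is only illustrative:

    from pytorch_mppi import autotune

    # Parameter created detached from any controller; dim must be given explicitly.
    sigma = autotune.SigmaParameter(None, dim=2)
    # Later (e.g. inside a worker process) rebuild the controller and re-bind the parameter.
    controller = make_controller()                # hypothetical factory returning an MPPI object
    sigma.attach_to_state({'mppi': controller})   # sets mppi, d, and dtype, mirroring __init__
    print(sigma.get_current_parameter_value())    # diagonal of the controller's noise_sigma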
@@ -132,10 +145,10 @@ def name():
         return 'sigma'

     def dim(self):
-        return self.mppi.nu
+        return self._dim

     def get_current_parameter_value(self):
-        return torch.cat([self.mppi.noise_sigma[i][i].view(1) for i in range(self.mppi.nu)])
+        return torch.cat([self.mppi.noise_sigma[i][i].view(1) for i in range(self.dim())])

     def ensure_valid_value(self, value):
         sigma = ensure_tensor(self.d, self.dtype, value)
@@ -149,10 +162,10 @@ def apply_parameter_value(self, value):
         self.mppi.noise_sigma_inv = torch.inverse(self.mppi.noise_sigma.detach())

     def get_parameter_value_from_config(self, config):
-        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.mppi.nu)], dtype=self.dtype, device=self.d)
+        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.dim())], dtype=self.dtype, device=self.d)

     def get_config_from_parameter_value(self, value):
-        return {f'{self.name()}{i}': value[i].item() for i in range(self.mppi.nu)}
+        return {f'{self.name()}{i}': value[i].item() for i in range(self.dim())}


 class MuParameter(MPPIParameter):
@@ -161,7 +174,7 @@ def name():
         return 'mu'

     def dim(self):
-        return self.mppi.nu
+        return self._dim

     def get_current_parameter_value(self):
         return self.mppi.noise_mu.clone()
@@ -176,10 +189,10 @@ def apply_parameter_value(self, value):
         self.mppi.noise_sigma_inv = torch.inverse(self.mppi.noise_sigma.detach())

     def get_parameter_value_from_config(self, config):
-        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.mppi.nu)], dtype=self.dtype, device=self.d)
+        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.dim())], dtype=self.dtype, device=self.d)

     def get_config_from_parameter_value(self, value):
-        return {f'{self.name()}{i}': value[i].item() for i in range(self.mppi.nu)}
+        return {f'{self.name()}{i}': value[i].item() for i in range(self.dim())}


 class LambdaParameter(MPPIParameter):
@@ -236,15 +249,25 @@ class Autotune:
     eps = 0.0001

     def __init__(self, params_to_tune: typing.Sequence[TunableParameter],
-                 evaluate_fn: typing.Callable[[], EvaluationResult], optimizer=CMAESOpt()):
+                 evaluate_fn: typing.Callable[[], EvaluationResult],
+                 reload_state_fn: typing.Callable[[], dict] = None,
+                 optimizer=CMAESOpt()):
+        """
+
+        :param params_to_tune: sequence of tunable parameters
+        :param evaluate_fn: function that returns an EvaluationResult that we want to minimize
+        :param reload_state_fn: function that returns a dictionary of state to reattach to the parameters
+        :param optimizer: optimizer that searches in the parameter space
+        """
         self.evaluate_fn = evaluate_fn
+        self.reload_state_fn = reload_state_fn

         self.params = params_to_tune
         self.optim = optimizer
         self.optim.tuner = self
         self.results = []

-        self.get_parameter_values(self.params)
+        self.attach_parameters()
         self.optim.setup_optimization()

     def optimize_step(self) -> EvaluationResult:
@@ -303,6 +326,17 @@ def apply_parameters(self, param_values):
         for p in self.params:
             p.apply_parameter_value(param_values[p.name()])

+    def attach_parameters(self):
+        """Attach parameters to any underlying state they require In most cases the parameters are defined already
+        attached to whatever state it needs, e.g. the MPPI controller object for changing the parameter values.
+        However, there are cases where the full state is not serializable, e.g. when using a multiprocessing pool
+        and so we pass only the information required to load the state. We then must load the state and reattach
+        the parameters to the state each training iteration."""
+        if self.reload_state_fn is not None:
+            state = self.reload_state_fn()
+            for p in self.params:
+                p.attach_to_state(state)
+
     def config_to_params(self, config):
         """Configs are param dictionaries where each must be a scalar"""
         return {p.name(): p.get_parameter_value_from_config(config) for p in self.params}
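
The reload_state_fn hook is aimed at setups where the controller itself cannot be serialized (e.g. when evaluation runs in a separate process), so the tuner rebuilds and re-attaches the state before optimization and at each trial. A minimal sketch under those assumptions, not part of the commit; make_controller and run_rollouts are hypothetical placeholders, and run_rollouts must return an EvaluationResult:

    from pytorch_mppi import autotune

    def reload_state():
        # Rebuild the non-picklable state inside the worker; attach_to_state expects
        # the controller under the 'mppi' key.
        return {'mppi': make_controller()}

    def evaluate():
        return run_rollouts()   # returns an EvaluationResult whose costs are minimized

    params = [autotune.SigmaParameter(None, dim=2)]   # detached parameter, dim given up front
    tuner = autotune.Autotune(params, evaluate, reload_state_fn=reload_state)
    res = tuner.optimize_step()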

Diff for: src/pytorch_mppi/autotune_global.py

+17 −12
@@ -1,8 +1,10 @@
 import abc
 import numpy as np
+import torch.cuda

 # pip install "ray[tune]" bayesian-optimization hyperopt
 from ray import tune
+from ray import train

 from pytorch_mppi import autotune
 from ray.tune.search.hyperopt import HyperOptSearch
@@ -47,35 +49,35 @@ def _linearize_space_value(space, v):


 class SigmaGlobalParameter(autotune.SigmaParameter, GlobalTunableParameter):
-    def __init__(self, *args, search_space=tune.loguniform(1e-4, 1e2)):
-        super().__init__(*args)
+    def __init__(self, *args, search_space=tune.loguniform(1e-4, 1e2), **kwargs):
+        super().__init__(*args, **kwargs)
         GlobalTunableParameter.__init__(self, search_space)

     def total_search_space(self) -> dict:
-        return {f"{self.name()}{i}": self.search_space for i in range(self.mppi.nu)}
+        return {f"{self.name()}{i}": self.search_space for i in range(self.dim())}


 class MuGlobalParameter(autotune.MuParameter, GlobalTunableParameter):
-    def __init__(self, *args, search_space=tune.uniform(-1, 1)):
-        super().__init__(*args)
+    def __init__(self, *args, search_space=tune.uniform(-1, 1), **kwargs):
+        super().__init__(*args, **kwargs)
         GlobalTunableParameter.__init__(self, search_space)

     def total_search_space(self) -> dict:
-        return {f"{self.name()}{i}": self.search_space for i in range(self.mppi.nu)}
+        return {f"{self.name()}{i}": self.search_space for i in range(self.dim())}


 class LambdaGlobalParameter(autotune.LambdaParameter, GlobalTunableParameter):
-    def __init__(self, *args, search_space=tune.loguniform(1e-5, 1e3)):
-        super().__init__(*args)
+    def __init__(self, *args, search_space=tune.loguniform(1e-5, 1e3), **kwargs):
+        super().__init__(*args, **kwargs)
         GlobalTunableParameter.__init__(self, search_space)

     def total_search_space(self) -> dict:
         return {self.name(): self.search_space}


 class HorizonGlobalParameter(autotune.HorizonParameter, GlobalTunableParameter):
-    def __init__(self, *args, search_space=tune.randint(1, 50)):
-        super().__init__(*args)
+    def __init__(self, *args, search_space=tune.randint(1, 50), **kwargs):
+        super().__init__(*args, **kwargs)
         GlobalTunableParameter.__init__(self, search_space)

     def total_search_space(self) -> dict:
@@ -124,8 +126,10 @@ def setup_optimization(self):
         init = self.tuner.initial_value()

         hyperopt_search = self.search_alg(points_to_evaluate=[init], metric="cost", mode="min")
+
+        trainable_with_resources = tune.with_resources(self.trainable, {"gpu": 1 if torch.cuda.is_available() else 0})
         self.optim = tune.Tuner(
-            self.trainable,
+            trainable_with_resources,
             tune_config=tune.TuneConfig(
                 num_samples=self.iterations,
                 search_alg=hyperopt_search,
@@ -136,9 +140,10 @@ def setup_optimization(self):
         )

     def trainable(self, config):
+        self.tuner.attach_parameters()
         self.tuner.apply_parameters(self.tuner.config_to_params(config))
         res = self.tuner.evaluate_fn()
-        tune.report(cost=res.costs.mean().item())
+        train.report({'cost': res.costs.mean().item()})

     def optimize_step(self):
         raise RuntimeError("Ray optimizers only allow tuning of all iterations at once")
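
For reference, the newer Ray Tune pattern this commit migrates to (metrics via ray.train.report, per-trial resources via tune.with_resources) looks roughly like the following standalone sketch with a toy objective, assuming Ray >= 2.x:

    import torch.cuda
    from ray import tune, train

    def trainable(config):
        cost = (config["x"] - 3.0) ** 2   # toy objective standing in for evaluate_fn
        train.report({"cost": cost})      # new-style reporting, replaces tune.report(cost=...)

    # Resources are declared on the trainable rather than passed to the run call.
    trainable_with_resources = tune.with_resources(
        trainable, {"gpu": 1 if torch.cuda.is_available() else 0})
    tuner = tune.Tuner(
        trainable_with_resources,
        param_space={"x": tune.uniform(-5.0, 5.0)},
        tune_config=tune.TuneConfig(num_samples=10, metric="cost", mode="min"),
    )
    results = tuner.fit()
    print(results.get_best_result().config)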
