@@ -108,6 +108,10 @@ def ensure_valid_value(self, value):
     def apply_parameter_value(self, value):
         """Apply the parameter value to the underlying object"""
 
+    @abc.abstractmethod
+    def attach_to_state(self, state: dict):
+        """Reattach/reinitialize the parameter to a new internal state. This should be similar to a call to __init__"""
+
     def get_parameter_value_from_config(self, config):
         """Get the serialized value of the parameter from a config dictionary, where each name is a scalar"""
         return config[self.name()]
@@ -118,10 +122,19 @@ def get_config_from_parameter_value(self, value):
 
 
 class MPPIParameter(TunableParameter, abc.ABC):
-    def __init__(self, mppi: MPPI):
+    def __init__(self, mppi: MPPI, dim=None):
         self.mppi = mppi
-        self.d = mppi.d
-        self.dtype = mppi.dtype
+        self._dim = dim
+        if self.mppi is not None:
+            self.d = self.mppi.d
+            self.dtype = self.mppi.dtype
+            if dim is None:
+                self._dim = self.mppi.nu
+
+    def attach_to_state(self, state: dict):
+        self.mppi = state['mppi']
+        self.d = self.mppi.d
+        self.dtype = self.mppi.dtype
 
 
 class SigmaParameter(MPPIParameter):
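
The net effect of this hunk is that an MPPIParameter can now be constructed detached from any controller and bound to one later. A minimal sketch of that flow, assuming a live `mppi` controller built elsewhere:

    sigma = SigmaParameter(None, dim=2)    # detached: no controller yet, only the control dimension
    sigma.attach_to_state({'mppi': mppi})  # bind later; this also sets self.d and self.dtype
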
@@ -132,10 +145,10 @@ def name():
         return 'sigma'
 
     def dim(self):
-        return self.mppi.nu
+        return self._dim
 
     def get_current_parameter_value(self):
-        return torch.cat([self.mppi.noise_sigma[i][i].view(1) for i in range(self.mppi.nu)])
+        return torch.cat([self.mppi.noise_sigma[i][i].view(1) for i in range(self.dim())])
 
     def ensure_valid_value(self, value):
         sigma = ensure_tensor(self.d, self.dtype, value)
@@ -149,10 +162,10 @@ def apply_parameter_value(self, value):
         self.mppi.noise_sigma_inv = torch.inverse(self.mppi.noise_sigma.detach())
 
     def get_parameter_value_from_config(self, config):
-        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.mppi.nu)], dtype=self.dtype, device=self.d)
+        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.dim())], dtype=self.dtype, device=self.d)
 
     def get_config_from_parameter_value(self, value):
-        return {f'{self.name()}{i}': value[i].item() for i in range(self.mppi.nu)}
+        return {f'{self.name()}{i}': value[i].item() for i in range(self.dim())}
 
 
 class MuParameter(MPPIParameter):
@@ -161,7 +174,7 @@ def name():
         return 'mu'
 
     def dim(self):
-        return self.mppi.nu
+        return self._dim
 
     def get_current_parameter_value(self):
         return self.mppi.noise_mu.clone()
@@ -176,10 +189,10 @@ def apply_parameter_value(self, value):
         self.mppi.noise_sigma_inv = torch.inverse(self.mppi.noise_sigma.detach())
 
     def get_parameter_value_from_config(self, config):
-        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.mppi.nu)], dtype=self.dtype, device=self.d)
+        return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.dim())], dtype=self.dtype, device=self.d)
 
     def get_config_from_parameter_value(self, value):
-        return {f'{self.name()}{i}': value[i].item() for i in range(self.mppi.nu)}
+        return {f'{self.name()}{i}': value[i].item() for i in range(self.dim())}
 
 
 class LambdaParameter(MPPIParameter):
@@ -236,15 +249,25 @@ class Autotune:
     eps = 0.0001
 
     def __init__(self, params_to_tune: typing.Sequence[TunableParameter],
-                 evaluate_fn: typing.Callable[[], EvaluationResult], optimizer=CMAESOpt()):
+                 evaluate_fn: typing.Callable[[], EvaluationResult],
+                 reload_state_fn: typing.Callable[[], dict] = None,
+                 optimizer=CMAESOpt()):
+        """
+
+        :param params_to_tune: sequence of tunable parameters
+        :param evaluate_fn: function that returns an EvaluationResult that we want to minimize
+        :param reload_state_fn: function that returns a dictionary of state to reattach to the parameters
+        :param optimizer: optimizer that searches in the parameter space
+        """
         self.evaluate_fn = evaluate_fn
+        self.reload_state_fn = reload_state_fn
 
         self.params = params_to_tune
         self.optim = optimizer
         self.optim.tuner = self
         self.results = []
 
-        self.get_parameter_values(self.params)
+        self.attach_parameters()
         self.optim.setup_optimization()
 
     def optimize_step(self) -> EvaluationResult:
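
For illustration, a hedged sketch of how the new keyword slots into the constructor; `evaluate` and `reload_mppi_state` are hypothetical user-supplied functions, the latter returning a dict such as {'mppi': controller}:

    tuner = Autotune([SigmaParameter(None, dim=2)],
                     evaluate_fn=evaluate,               # returns the EvaluationResult to minimize
                     reload_state_fn=reload_mppi_state,  # supplies fresh state to reattach each iteration
                     optimizer=CMAESOpt())
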
@@ -303,6 +326,17 @@ def apply_parameters(self, param_values):
         for p in self.params:
             p.apply_parameter_value(param_values[p.name()])
 
+    def attach_parameters(self):
+        """Attach parameters to any underlying state they require. In most cases the parameters are defined already
+        attached to whatever state they need, e.g. the MPPI controller object for changing the parameter values.
+        However, there are cases where the full state is not serializable, e.g. when using a multiprocessing pool,
+        and so we pass only the information required to load the state. We then must load the state and reattach
+        the parameters to the state each training iteration."""
+        if self.reload_state_fn is not None:
+            state = self.reload_state_fn()
+            for p in self.params:
+                p.attach_to_state(state)
+
     def config_to_params(self, config):
         """Configs are param dictionaries where each must be a scalar"""
         return {p.name(): p.get_parameter_value_from_config(config) for p in self.params}
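
A hedged example of the kind of reload_state_fn this docstring describes: inside a multiprocessing worker, rebuild the controller from cheap, picklable pieces and return it under the 'mppi' key that MPPIParameter.attach_to_state reads. The MPPI constructor arguments and the toy dynamics/cost below are assumptions for illustration, not part of this diff:

    def reload_mppi_state():
        dynamics = lambda state, u: state + u                      # toy placeholder dynamics
        running_cost = lambda state, u: (state ** 2).sum(dim=-1)   # toy placeholder cost
        controller = MPPI(dynamics, running_cost, nx=2, noise_sigma=torch.eye(2),
                          num_samples=500, horizon=15)
        return {'mppi': controller}
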