Skip to content

Commit 19028d8

Browse files
authored
Delay entering TrialRunner context until run_trial (#970)
# Pull Request ## Title Delay entering `TrialRunner` context until `run_trial`. ______________________________________________________________________ ## Description This is part of an attempt to try and see if can work around issues with `multiprocessing.Pool` needing to pickle certain objects when forking. For instance, if the Environment is using an SshServer, we need to start an EventLoopContext in the background to handle the SSH connections and threads are not picklable. Nor are file handles, DB connections, etc., so there may be other things we also need to adjust to make this work. See Also #967 ______________________________________________________________________ ## Type of Change - 🛠️ Bug fix - 🔄 Refactor ______________________________________________________________________ ## Testing - Light so far (still in draft mode) - Just basic existing CI tests (seems to not break anything) ______________________________________________________________________ ## Additional Notes (optional) I think this is incomplete. To support forking inside the Scheduler and *then* entering the context of the given TrialRunner, we may also need to do something about the Scheduler's Storage object. That was true, those PRs are now forthcoming. See Also #971 For now this is a draft PR to allow @jsfreischuetz and I to play with alternative organizations of #967. ______________________________________________________________________
1 parent 2fbaaca commit 19028d8

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

mlos_bench/mlos_bench/schedulers/base_scheduler.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,9 @@ def __enter__(self) -> "Scheduler":
200200
_LOG.debug("Scheduler START :: %s", self)
201201
assert self.experiment is None
202202
assert not self._in_context
203-
for trial_runner in self._trial_runners.values():
204-
trial_runner.__enter__()
203+
# NOTE: We delay entering the context of trial_runners until it's time
204+
# to run the trial in order to avoid incompatibilities with
205+
# multiprocessing.Pool.
205206
self._optimizer.__enter__()
206207
# Start new or resume the existing experiment. Verify that the
207208
# experiment configuration is compatible with the previous runs.
@@ -235,7 +236,8 @@ def __exit__(
235236
self._experiment.__exit__(ex_type, ex_val, ex_tb)
236237
self._optimizer.__exit__(ex_type, ex_val, ex_tb)
237238
for trial_runner in self._trial_runners.values():
238-
trial_runner.__exit__(ex_type, ex_val, ex_tb)
239+
# TrialRunners should have already exited their context after running the Trial.
240+
assert not trial_runner._in_context # pylint: disable=protected-access
239241
self._experiment = None
240242
self._in_context = False
241243
return False # Do not suppress exceptions
@@ -267,7 +269,8 @@ def teardown(self) -> None:
267269
if self._do_teardown:
268270
for trial_runner in self._trial_runners.values():
269271
assert not trial_runner.is_running
270-
trial_runner.teardown()
272+
with trial_runner:
273+
trial_runner.teardown()
271274

272275
def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
273276
"""Get the best observation from the optimizer."""

mlos_bench/mlos_bench/schedulers/sync_scheduler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ def run_trial(self, trial: Storage.Trial) -> None:
3939
super().run_trial(trial)
4040
# In the sync scheduler we run each trial on its own TrialRunner in sequence.
4141
trial_runner = self.get_trial_runner(trial)
42-
trial_runner.run_trial(trial, self.global_config)
43-
_LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)
42+
with trial_runner:
43+
trial_runner.run_trial(trial, self.global_config)
44+
_LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)

0 commit comments

Comments
 (0)