Skip to content

Delay entering TrialRunner context until run_trial #970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions mlos_bench/mlos_bench/schedulers/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@ def __enter__(self) -> "Scheduler":
_LOG.debug("Scheduler START :: %s", self)
assert self.experiment is None
assert not self._in_context
for trial_runner in self._trial_runners.values():
trial_runner.__enter__()
# NOTE: We delay entering the context of trial_runners until it's time
# to run the trial in order to avoid incompatibilities with
# multiprocessing.Pool.
self._optimizer.__enter__()
# Start new or resume the existing experiment. Verify that the
# experiment configuration is compatible with the previous runs.
Expand Down Expand Up @@ -235,7 +236,8 @@ def __exit__(
self._experiment.__exit__(ex_type, ex_val, ex_tb)
self._optimizer.__exit__(ex_type, ex_val, ex_tb)
for trial_runner in self._trial_runners.values():
trial_runner.__exit__(ex_type, ex_val, ex_tb)
# TrialRunners should have already exited their context after running the Trial.
assert not trial_runner._in_context # pylint: disable=protected-access
self._experiment = None
self._in_context = False
return False # Do not suppress exceptions
Expand Down Expand Up @@ -267,7 +269,8 @@ def teardown(self) -> None:
if self._do_teardown:
for trial_runner in self._trial_runners.values():
assert not trial_runner.is_running
trial_runner.teardown()
with trial_runner:
trial_runner.teardown()

def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
"""Get the best observation from the optimizer."""
Expand Down
5 changes: 3 additions & 2 deletions mlos_bench/mlos_bench/schedulers/sync_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ def run_trial(self, trial: Storage.Trial) -> None:
super().run_trial(trial)
# In the sync scheduler we run each trial on its own TrialRunner in sequence.
trial_runner = self.get_trial_runner(trial)
trial_runner.run_trial(trial, self.global_config)
_LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)
with trial_runner:
trial_runner.run_trial(trial, self.global_config)
_LOG.info("QUEUE: Finished trial: %s on %s", trial, trial_runner)
Loading