diff --git a/mlos_bench/mlos_bench/config/schedulers/parallel_scheduler.jsonc b/mlos_bench/mlos_bench/config/schedulers/parallel_scheduler.jsonc
new file mode 100644
index 00000000000..4d6b7a7e272
--- /dev/null
+++ b/mlos_bench/mlos_bench/config/schedulers/parallel_scheduler.jsonc
@@ -0,0 +1,12 @@
+// Parallel scheduler that runs trials concurrently on multiple trial runners.
+{
+    "$schema": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json",
+
+    "class": "mlos_bench.schedulers.ParallelScheduler",
+
+    "config": {
+        "trial_config_repeat_count": 3,
+        "max_trials": -1,   // Limited only in the Optimizer logic/config.
+        "teardown": false
+    }
+}
diff --git a/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json b/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json
index 81b2e797547..dedac1ed758 100644
--- a/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json
+++ b/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json
@@ -2,12 +2,10 @@
     "$schema": "https://json-schema.org/draft/2020-12/schema",
     "$id": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json",
     "title": "mlos_bench Scheduler config",
-
     "$defs": {
         "comment": {
             "$comment": "This section contains reusable partial schema bits (or just split out for readability)"
         },
-
         "config_base_scheduler": {
             "$comment": "config properties common to all Scheduler types.",
             "type": "object",
@@ -29,18 +27,23 @@
                     "description": "Max. number of trials to run. Use -1 or 0 for unlimited.",
                     "type": "integer",
                     "minimum": -1,
-                    "examples": [50, -1]
+                    "examples": [
+                        50,
+                        -1
+                    ]
                 },
                 "trial_config_repeat_count": {
                     "description": "Number of times to repeat a config.",
                     "type": "integer",
                     "minimum": 1,
-                    "examples": [3, 5]
+                    "examples": [
+                        3,
+                        5
+                    ]
                 }
             }
         }
     },
-
     "description": "config for the mlos_bench scheduler",
     "$comment": "top level schema document rules",
     "type": "object",
@@ -51,21 +54,20 @@
             "$comment": "This is optional, but if provided, should match the name of this file.",
             "pattern": "/schemas/schedulers/scheduler-schema.json$"
         },
-
         "description": {
             "description": "Optional description of the config.",
             "type": "string"
         },
-
        "class": {
            "description": "The name of the scheduler class to use.",
            "$comment": "required",
            "enum": [
                "mlos_bench.schedulers.SyncScheduler",
-                "mlos_bench.schedulers.sync_scheduler.SyncScheduler"
+                "mlos_bench.schedulers.sync_scheduler.SyncScheduler",
+                "mlos_bench.schedulers.ParallelScheduler",
+                "mlos_bench.schedulers.parallel_scheduler.ParallelScheduler"
            ]
        },
-
        "config": {
            "description": "The scheduler-specific config.",
            "$comment": "Stub for scheduler-specific config appended with condition statements below",
@@ -73,8 +75,9 @@
            "minProperties": 1
        }
    },
-    "required": ["class"],
-
+    "required": [
+        "class"
+    ],
    "oneOf": [
        {
            "$comment": "extensions to the 'config' object properties when synchronous scheduler is being used",
@@ -83,17 +86,25 @@
                "class": {
                    "enum": [
                        "mlos_bench.schedulers.SyncScheduler",
-                        "mlos_bench.schedulers.sync_scheduler.SyncScheduler"
+                        "mlos_bench.schedulers.sync_scheduler.SyncScheduler",
+                        "mlos_bench.schedulers.ParallelScheduler",
+                        "mlos_bench.schedulers.parallel_scheduler.ParallelScheduler"
                    ]
                }
            },
-            "required": ["class"]
+            "required": [
+                "class"
+            ]
        },
        "then": {
            "properties": {
                "config": {
                    "type": "object",
-                    "allOf": [{ "$ref": "#/$defs/config_base_scheduler" }],
+                    "allOf": [
+                        {
+                            "$ref": "#/$defs/config_base_scheduler"
+                        }
+                    ],
                    "$comment": "disallow other properties",
                    "unevaluatedProperties": false
                }
diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py
index ca35b3473da..066b659f154 100644
--- a/mlos_bench/mlos_bench/environments/status.py
+++ b/mlos_bench/mlos_bench/environments/status.py
@@ -18,6 +18,7 @@ class Status(enum.Enum):
     CANCELED = 5
     FAILED = 6
     TIMED_OUT = 7
+    SCHEDULED = 8

     def is_good(self) -> bool:
         """Check if the status of the benchmark/environment is good."""
@@ -26,6 +27,7 @@ def is_good(self) -> bool:
             Status.READY,
             Status.RUNNING,
             Status.SUCCEEDED,
+            Status.SCHEDULED,
         }

     def is_completed(self) -> bool:
@@ -74,3 +76,9 @@ def is_timed_out(self) -> bool:
         TIMED_OUT.
         """
         return self == Status.FAILED
+
+    def is_scheduled(self) -> bool:
+        """Check if the status of the benchmark/environment Trial or Experiment is
+        SCHEDULED.
+        """
+        return self == Status.SCHEDULED
diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py
index 381261e53da..fd381612be7 100644
--- a/mlos_bench/mlos_bench/schedulers/__init__.py
+++ b/mlos_bench/mlos_bench/schedulers/__init__.py
@@ -5,9 +5,11 @@
 """Interfaces and implementations of the optimization loop scheduling policies."""

 from mlos_bench.schedulers.base_scheduler import Scheduler
+from mlos_bench.schedulers.parallel_scheduler import ParallelScheduler
 from mlos_bench.schedulers.sync_scheduler import SyncScheduler

 __all__ = [
     "Scheduler",
     "SyncScheduler",
+    "ParallelScheduler",
 ]
diff --git a/mlos_bench/mlos_bench/schedulers/parallel_scheduler.py b/mlos_bench/mlos_bench/schedulers/parallel_scheduler.py
new file mode 100644
index 00000000000..2092d5f5bb8
--- /dev/null
+++ b/mlos_bench/mlos_bench/schedulers/parallel_scheduler.py
@@ -0,0 +1,162 @@
+#
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+#
+"""A simple multi-process asynchronous optimization loop implementation."""
+
+import asyncio
+import logging
+from collections.abc import Callable
+from concurrent.futures import Future, ProcessPoolExecutor
+from datetime import datetime
+from typing import Any
+
+from pytz import UTC
+
+from mlos_bench.environments.status import Status
+from mlos_bench.schedulers.base_scheduler import Scheduler
+from mlos_bench.schedulers.trial_runner import TrialRunner
+from mlos_bench.storage.base_storage import Storage
+from mlos_bench.tunables.tunable_groups import TunableGroups
+
+_LOG = logging.getLogger(__name__)
+
+
+class ParallelScheduler(Scheduler):
+    """A simple multi-process asynchronous optimization loop implementation."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+
+        super().__init__(*args, **kwargs)
+        self.pool = ProcessPoolExecutor(max_workers=len(self._trial_runners))
+
+    def start(self) -> None:
+        """Start the optimization loop."""
+        super().start()
+
+        is_warm_up: bool = self.optimizer.supports_preload
+        if not is_warm_up:
+            _LOG.warning("Skip pending trials and warm-up: %s", self.optimizer)
+
+        not_done: bool = True
+        while not_done:
+            _LOG.info("Optimization loop: Last trial ID: %d", self._last_trial_id)
+            self._run_callbacks()
+            self._run_schedule(is_warm_up)
+            not_done = self._schedule_new_optimizer_suggestions()
+            is_warm_up = False
+
+    def teardown(self) -> None:
+        """Stop the optimization loop."""
+        # Shut down the process pool and wait for all tasks to finish
+        self.pool.shutdown(wait=True)
+        self._run_callbacks()
+        super().teardown()
+
+    def schedule_trial(self, tunables: TunableGroups) -> None:
+        """Assign a trial to a trial runner."""
+        assert self.experiment is not None
+
+        super().schedule_trial(tunables)
+
+        pending_trials: list[Storage.Trial] = list(
+            self.experiment.pending_trials(datetime.now(UTC), running=False)
+        )
+
+        idle_runner_ids = [
+            id for id, runner in self.trial_runners.items() if not runner.is_running
+        ]
+
+        # Assign pending trials to idle runners
+        for trial, runner_id in zip(pending_trials, idle_runner_ids):
+            trial.update(status=Status.SCHEDULED, timestamp=datetime.now(UTC))
+            trial.set_trial_runner(runner_id)
+
+    def _run_schedule(self, running: bool = False) -> None:
+        """
+        Scheduler part of the loop.
+
+        Check for pending trials in the queue and run them.
+        """
+        assert self.experiment is not None
+
+        scheduled_trials: list[Storage.Trial] = list(
+            self.experiment.filter_trials_by_status(datetime.now(UTC), [Status.SCHEDULED])
+        )
+
+        for trial in scheduled_trials:
+            trial.update(status=Status.READY, timestamp=datetime.now(UTC))
+            self.deferred_run_trial(trial)
+
+    def _on_trial_finished_closure(
+        self, trial: Storage.Trial
+    ) -> Callable[["ParallelScheduler", Future], None]:
+        """
+        Generate a closure to handle the callback for when a trial is finished.
+
+        Parameters
+        ----------
+        trial : Storage.Trial
+            The trial to finish.
+        """
+
+        def _on_trial_finished(self: ParallelScheduler, result: Future) -> None:
+            """
+            Callback to be called when a trial is finished.
+
+            This must always be called from the main thread.
+            Exceptions can also be handled here.
+            """
+            try:
+                (status, timestamp, results, telemetry) = result.result()
+                self.get_trial_runner(trial).finalize_run_trial(
+                    trial, status, timestamp, results, telemetry
+                )
+            except Exception as exception:  # pylint: disable=broad-except
+                _LOG.error("Trial failed: %s", exception)
+
+        return _on_trial_finished
+
+    @staticmethod
+    def _run_callbacks() -> None:
+        """Run all pending callbacks in the main thread."""
+        loop = asyncio.get_event_loop()
+        pending = asyncio.all_tasks(loop)
+        loop.run_until_complete(asyncio.gather(*pending))
+
+    def run_trial(self, trial: Storage.Trial) -> None:
+        """
+        ParallelScheduler does not support run_trial. Use deferred_run_trial instead.
+
+        Parameters
+        ----------
+        trial : Storage.Trial
+            The trial to run.
+
+        Raises
+        ------
+        NotImplementedError
+            Error to indicate that this method is not supported in ParallelScheduler.
+        """
+        raise NotImplementedError(
+            "ParallelScheduler does not support run_trial. Use deferred_run_trial instead."
+        )
+
+    def deferred_run_trial(self, trial: Storage.Trial) -> None:
+        """
+        Set up and run a single trial asynchronously.
+
+        Returns a callback to save the results in the storage.
+        """
+        super().run_trial(trial)
+        # Prepare the trial on its assigned TrialRunner before handing it off to the pool.
+        trial_runner = self.get_trial_runner(trial)
+        trial_runner.prepare_run_trial(trial, self.global_config)
+
+        task = self.pool.submit(TrialRunner.execute_run_trial, trial_runner.environment)
+        # This is required to ensure that the callback happens on the main thread
+        asyncio.get_event_loop().call_soon_threadsafe(
+            self._on_trial_finished_closure(trial), self, task
+        )
+
+        _LOG.info("QUEUE: Submitted trial: %s on %s", trial, trial_runner)
diff --git a/mlos_bench/mlos_bench/schedulers/trial_runner.py b/mlos_bench/mlos_bench/schedulers/trial_runner.py
index 80eb696bc6d..3c5e62e690a 100644
--- a/mlos_bench/mlos_bench/schedulers/trial_runner.py
+++ b/mlos_bench/mlos_bench/schedulers/trial_runner.py
@@ -13,13 +13,13 @@
 from mlos_bench.environments.base_environment import Environment
 from mlos_bench.environments.status import Status
-from mlos_bench.event_loop_context import EventLoopContext
 from mlos_bench.services.base_service import Service
 from mlos_bench.services.config_persistence import ConfigPersistenceService
 from mlos_bench.services.local.local_exec import LocalExecService
 from mlos_bench.services.types import SupportsConfigLoading
 from mlos_bench.storage.base_storage import Storage
 from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_bench.tunables.tunable_types import TunableValue

 _LOG = logging.getLogger(__name__)

@@ -117,7 +117,6 @@ def __init__(self, trial_runner_id: int, env: Environment) -> None:
         assert self._env.parameters["trial_runner_id"] == self._trial_runner_id
         self._in_context = False
         self._is_running = False
-        self._event_loop_context = EventLoopContext()

     def __repr__(self) -> str:
         return (
@@ -164,26 +163,20 @@ def is_running(self) -> bool:
         """Get the running state of the current TrialRunner."""
         return self._is_running

-    def run_trial(
+    def prepare_run_trial(
         self,
         trial: Storage.Trial,
         global_config: dict[str, Any] | None = None,
     ) -> None:
         """
-        Run a single trial on this TrialRunner's Environment and stores the results in
-        the backend Trial Storage.
+        Prepare the trial runner for running a trial.

         Parameters
         ----------
         trial : Storage.Trial
-            A Storage class based Trial used to persist the experiment trial data.
-        global_config : dict
-            Global configuration parameters.
-
-        Returns
-        -------
-        (trial_status, trial_score) : (Status, dict[str, float] | None)
-            Status and results of the trial.
+            The trial to prepare.
+        global_config : dict[str, Any] | None
+            Global configuration parameters, by default None
         """
         assert self._in_context
@@ -196,31 +189,88 @@ def run_trial(
         )

         if not self.environment.setup(trial.tunables, trial.config(global_config)):
-            _LOG.warning("Setup failed: %s :: %s", self.environment, trial.tunables)
-            # FIXME: Use the actual timestamp from the environment.
-            _LOG.info("TrialRunner: Update trial results: %s :: %s", trial, Status.FAILED)
             trial.update(Status.FAILED, datetime.now(UTC))
-            return
-
-        # TODO: start background status polling of the environments in the event loop.
+
+    @staticmethod
+    def execute_run_trial(
+        environment: Environment,
+    ) -> tuple[Status, datetime, dict[str, TunableValue] | None, list[tuple[datetime, str, Any]]]:
+        """
+        Execute the trial run on the environment.
+
+        Parameters
+        ----------
+        environment : Environment
+            The environment to run the trial on.
+
+        Returns
+        -------
+        tuple[
+            Status,
+            datetime.datetime,
+            dict[str, TunableValue] | None,
+            list[tuple[datetime.datetime, str, Any]]
+        ]
+            The full results of the trial run, including status, timestamp, results, and telemetry.
+        """
         # Block and wait for the final result.
-        (status, timestamp, results) = self.environment.run()
-        _LOG.info("TrialRunner Results: %s :: %s\n%s", trial.tunables, status, results)
+        (status, timestamp, results) = environment.run()

         # In async mode (TODO), poll the environment for status and telemetry
         # and update the storage with the intermediate results.
-        (_status, _timestamp, telemetry) = self.environment.status()
+        (_status, _timestamp, telemetry) = environment.status()

-        # Use the status and timestamp from `.run()` as it is the final status of the experiment.
-        # TODO: Use the `.status()` output in async mode.
-        trial.update_telemetry(status, timestamp, telemetry)
+        return (status, timestamp, results, telemetry)
+
+    def finalize_run_trial(  # pylint: disable=too-many-arguments, too-many-positional-arguments
+        self,
+        trial: Storage.Trial,
+        status: Status,
+        timestamp: datetime,
+        results: dict[str, TunableValue] | None,
+        telemetry: list[tuple[datetime, str, Any]],
+    ) -> None:
+        """
+        Finalize the trial run in the storage backend.
+
+        Parameters
+        ----------
+        trial : Storage.Trial
+            The trial to finalize.
+        status : Status
+            The status of the trial.
+        timestamp : datetime.datetime
+            The timestamp of the trial execution.
+        results : dict[str, TunableValue] | None
+            The results of the trial.
+        telemetry : list[tuple[datetime.datetime, str, Any]]
+            The telemetry data of the trial.
+        """
+        trial.update_telemetry(status, timestamp, telemetry)
         trial.update(status, timestamp, results)
         _LOG.info("TrialRunner: Update trial results: %s :: %s %s", trial, status, results)

-        self._is_running = False
+    def run_trial(
+        self,
+        trial: Storage.Trial,
+        global_config: dict[str, Any] | None = None,
+    ) -> None:
+        """
+        Run a single trial on this TrialRunner's Environment and store the results in
+        the backend Trial Storage.
+
+        Parameters
+        ----------
+        trial : Storage.Trial
+            A Storage class based Trial used to persist the experiment trial data.
+        global_config : dict
+            Global configuration parameters.
+        """
+        self.prepare_run_trial(trial, global_config)
+        (status, timestamp, results, telemetry) = self.execute_run_trial(self._env)
+        self.finalize_run_trial(trial, status, timestamp, results, telemetry)
+
     def teardown(self) -> None:
         """
         Tear down the Environment.
diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py
index f2d393994f7..87d61a6723e 100644
--- a/mlos_bench/mlos_bench/storage/base_storage.py
+++ b/mlos_bench/mlos_bench/storage/base_storage.py
@@ -307,6 +307,29 @@ def load(
             Trial ids, Tunable values, benchmark scores, and status of the trials.
         """

+    @abstractmethod
+    def filter_trials_by_status(
+        self,
+        timestamp: datetime,
+        statuses: list[Status],
+    ) -> Iterator["Storage.Trial"]:
+        """
+        Return an iterator over the pending trials that are scheduled to run on or
+        before the specified timestamp and that match one of the listed statuses.
+
+        Parameters
+        ----------
+        timestamp : datetime.datetime
+            The time in UTC to check for scheduled trials.
+        statuses : list[Status]
+            Status of the trials to filter in.
+
+        Returns
+        -------
+        trials : Iterator[Storage.Trial]
+            An iterator over the matching trials.
+        """
+
     @abstractmethod
     def pending_trials(
         self,
diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py
index eb47de7d714..97581b949cb 100644
--- a/mlos_bench/mlos_bench/storage/sql/experiment.py
+++ b/mlos_bench/mlos_bench/storage/sql/experiment.py
@@ -235,13 +235,13 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> d
             row._tuple()
             for row in cur_result.fetchall()  # pylint: disable=protected-access
         )

-    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
+    def filter_trials_by_status(
+        self,
+        timestamp: datetime,
+        statuses: list[Status],
+    ) -> Iterator[Storage.Trial]:
         timestamp = utcify_timestamp(timestamp, origin="local")
         _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
-        if running:
-            pending_status = [Status.PENDING.name, Status.READY.name, Status.RUNNING.name]
-        else:
-            pending_status = [Status.PENDING.name]
         with self._engine.connect() as conn:
             cur_trials = conn.execute(
                 self._schema.trial.select().where(
@@ -251,7 +251,7 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor
                         | (self._schema.trial.c.ts_start <= timestamp)
                     ),
                     self._schema.trial.c.ts_end.is_(None),
-                    self._schema.trial.c.status.in_(pending_status),
+                    self._schema.trial.c.status.in_([s.name for s in statuses]),
                 )
             )
             for trial in cur_trials.fetchall():
@@ -281,6 +281,13 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor
                 config=config,
             )

+    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
+        if running:
+            pending_status = [Status.PENDING, Status.READY, Status.RUNNING]
+        else:
+            pending_status = [Status.PENDING]
+        return self.filter_trials_by_status(timestamp=timestamp, statuses=pending_status)
+
     def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
         """
         Get the config ID for the given tunables.
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/invalid/parallel_sched-bad-repeat.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/invalid/parallel_sched-bad-repeat.jsonc
new file mode 100644
index 00000000000..4ea6bdbf170
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/invalid/parallel_sched-bad-repeat.jsonc
@@ -0,0 +1,6 @@
+{
+    "class": "mlos_bench.schedulers.ParallelScheduler",
+    "config": {
+        "trial_config_repeat_count": 0
+    }
+}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/invalid/parallel_sched-empty-config.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/invalid/parallel_sched-empty-config.jsonc
new file mode 100644
index 00000000000..06729a4f368
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/invalid/parallel_sched-empty-config.jsonc
@@ -0,0 +1,5 @@
+{
+    "class": "mlos_bench.schedulers.ParallelScheduler",
+    "config": {
+    }
+}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/unhandled/parallel_sched-extra.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/unhandled/parallel_sched-extra.jsonc
new file mode 100644
index 00000000000..68623ee611f
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/bad/unhandled/parallel_sched-extra.jsonc
@@ -0,0 +1,6 @@
+{
+    "class": "mlos_bench.schedulers.parallel_scheduler.ParallelScheduler",
+    "config": {
+        "extra": "unsupported"
+    }
+}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/full/parallel_sched-full.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/full/parallel_sched-full.jsonc
new file mode 100644
index 00000000000..90bac645032
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/full/parallel_sched-full.jsonc
@@ -0,0 +1,12 @@
+{
+    "$schema": "https://raw.githubusercontent.com/microsoft/MLOS/main/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json",
+    "class": "mlos_bench.schedulers.parallel_scheduler.ParallelScheduler",
+    "config": {
+        "trial_config_repeat_count": 3,
+        "teardown": false,
+        "experiment_id": "MyExperimentName",
+        "config_id": 1,
+        "trial_id": 1,
+        "max_trials": 100
+    }
+}
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/partial/parallel_sched-partial.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/partial/parallel_sched-partial.jsonc
new file mode 100644
index 00000000000..1b0e39c3305
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/partial/parallel_sched-partial.jsonc
@@ -0,0 +1,7 @@
+{
+    "class": "mlos_bench.schedulers.ParallelScheduler",
+    "config": {
+        "trial_config_repeat_count": 3,
+        "teardown": false
+    }
+}
diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py
index 6294ee8bf3b..b84245732f5 100644
--- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py
+++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py
@@ -18,7 +18,7 @@
 from mlos_bench.launcher import Launcher
 from mlos_bench.optimizers import MlosCoreOptimizer, OneShotOptimizer
 from mlos_bench.os_environ import environ
-from mlos_bench.schedulers import SyncScheduler
+from mlos_bench.schedulers import ParallelScheduler, SyncScheduler
 from mlos_bench.services.types import (
     SupportsAuth,
     SupportsConfigLoading,
@@ -307,5 +307,67 @@ def test_launcher_args_parse_3(config_paths: list[str]) -> None:
     assert launcher.scheduler.trial_config_repeat_count == 2


+def test_launcher_args_parse_4(config_paths: list[str]) -> None:
+    """
+    Test that the ParallelScheduler can be loaded via the --scheduler argument and that
+    multiple --globals arguments and multiple space-separated options to --config-paths work.
+
+    Check $var expansion and Environment loading.
+    """
+    # Here we have multiple paths following --config-paths and --service.
+    cli_args = (
+        "--config-paths "
+        + " ".join(config_paths)
+        + " --num-trial-runners 5"
+        + " --service services/remote/mock/mock_auth_service.jsonc"
+        " services/remote/mock/mock_remote_exec_service.jsonc"
+        " --scheduler schedulers/parallel_scheduler.jsonc"
+        f" --environment {ENV_CONF_PATH}"
+        " --globals globals/global_test_config.jsonc"
+        " --globals globals/global_test_extra_config.jsonc"
+        " --test_global_value_2 from-args"
+    )
+    launcher = _get_launcher(__name__, cli_args)
+    # Check some additional features of the parent service
+    assert isinstance(launcher.service, SupportsAuth)  # from --service
+    assert isinstance(launcher.service, SupportsRemoteExec)  # from --service
+    # Check that the first --globals file is loaded and $var expansion is handled.
+    assert launcher.global_config["experiment_id"] == "MockExperiment"
+    assert launcher.global_config["testVmName"] == "MockExperiment-vm"
+    # Check that secondary expansion also works.
+    assert launcher.global_config["testVnetName"] == "MockExperiment-vm-vnet"
+    # Check that the second --globals file is loaded.
+    assert launcher.global_config["test_global_value"] == "from-file"
+    # Check overriding values in a file from the command line.
+    assert launcher.global_config["test_global_value_2"] == "from-args"
+    # Check that we can expand a $var in a config file that references an environment variable.
+    assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join(
+        os.getcwd(), "foo", abs_path=True
+    )
+    assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}"
+    assert launcher.teardown
+    # Make sure we have the right number of trial runners.
+    assert len(launcher.trial_runners) == 5  # from cli args
+    # Check that the environment that got loaded looks to be of the right type.
+    env_config = launcher.config_loader.load_config(ENV_CONF_PATH, ConfigSchema.ENVIRONMENT)
+    assert env_config["class"] == "mlos_bench.environments.mock_env.MockEnv"
+    # All TrialRunners should get the same Environment.
+    assert all(
+        check_class_name(trial_runner.environment, env_config["class"])
+        for trial_runner in launcher.trial_runners
+    )
+    # Check that the optimizer looks right.
+    assert isinstance(launcher.optimizer, OneShotOptimizer)
+    # Check that the optimizer got initialized with defaults.
+    assert launcher.optimizer.tunable_params.is_defaults()
+    assert launcher.optimizer.max_suggestions == 1  # value for OneShotOptimizer
+    # Check that we pick up the right scheduler config:
+    assert isinstance(launcher.scheduler, ParallelScheduler)
+    assert (
+        launcher.scheduler.trial_config_repeat_count == 3
+    )  # from the custom parallel_scheduler.jsonc config
+    assert launcher.scheduler.max_trials == -1
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-n0"])
diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py
index a1437052823..212bf4acd4c 100644
--- a/mlos_bench/mlos_bench/tests/storage/conftest.py
+++ b/mlos_bench/mlos_bench/tests/storage/conftest.py
@@ -16,5 +16,7 @@
 exp_no_tunables_storage = sql_storage_fixtures.exp_no_tunables_storage
 mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage
 exp_data = sql_storage_fixtures.exp_data
+parallel_exp_data = sql_storage_fixtures.parallel_exp_data
+
 exp_no_tunables_data = sql_storage_fixtures.exp_no_tunables_data
 mixed_numerics_exp_data = sql_storage_fixtures.mixed_numerics_exp_data
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
index cb83bffd4ff..3cec974fcf5 100644
--- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
@@ -4,12 +4,13 @@
 #
 """Test fixtures for mlos_bench storage."""

-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from random import seed as rand_seed

 import pytest

 from mlos_bench.optimizers.mock_optimizer import MockOptimizer
+from mlos_bench.schedulers.parallel_scheduler import ParallelScheduler
 from mlos_bench.schedulers.sync_scheduler import SyncScheduler
 from mlos_bench.schedulers.trial_runner import TrialRunner
 from mlos_bench.services.config_persistence import ConfigPersistenceService
@@ -109,6 +110,94 @@ def mixed_numerics_exp_storage(
     assert not exp._in_context


+def _parallel_dummy_run_exp(
+    storage: SqlStorage,
+    exp: SqlStorage.Experiment,
+) -> ExperimentData:
+    """
+    Generates data by doing a simulated run of the given experiment.
+
+    Parameters
+    ----------
+    storage : SqlStorage
+        The storage object to use.
+    exp : SqlStorage.Experiment
+        The experiment to "run".
+        Note: this particular object won't be updated, but a new one will be created
+        from its metadata.
+
+    Returns
+    -------
+    ExperimentData
+        The data generated by the simulated run.
+    """
+    # pylint: disable=too-many-locals
+
+    rand_seed(SEED)
+
+    trial_runners: list[TrialRunner] = []
+    global_config: dict = {}
+    config_loader = ConfigPersistenceService()
+    tunable_params = ",".join(f'"{name}"' for name in exp.tunables.get_covariant_group_names())
+    mock_env_json = f"""
+    {{
+        "class": "mlos_bench.environments.mock_env.MockEnv",
+        "name": "Test Env",
+        "config": {{
+            "tunable_params": [{tunable_params}],
+            "mock_env_seed": {SEED},
+            "mock_env_range": [60, 120],
+            "mock_env_metrics": ["score"]
+        }}
+    }}
+    """
+    trial_runners = TrialRunner.create_from_json(
+        config_loader=config_loader,
+        global_config=global_config,
+        tunable_groups=exp.tunables,
+        env_json=mock_env_json,
+        svcs_json=None,
+        num_trial_runners=TRIAL_RUNNER_COUNT,
+    )
+
+    opt = MockOptimizer(
+        tunables=exp.tunables,
+        config={
+            "optimization_targets": exp.opt_targets,
+            "seed": SEED,
+            # This should be the default, so we leave it omitted for now to test the default.
+            # But the test logic relies on this (e.g., trial 1 is config 1 is the
+            # default values for the tunable params)
+            # "start_with_defaults": True,
+            "max_suggestions": MAX_TRIALS,
+        },
+        global_config=global_config,
+    )
+
+    scheduler = ParallelScheduler(
+        # All config values can be overridden from global config
+        config={
+            "experiment_id": exp.experiment_id,
+            "trial_id": exp.trial_id,
+            "config_id": -1,
+            "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
+            "max_trials": MAX_TRIALS,
+        },
+        global_config=global_config,
+        trial_runners=trial_runners,
+        optimizer=opt,
+        storage=storage,
+        root_env_config=exp.root_env_config,
+    )
+
+    # Add some trial data to that experiment by "running" it.
+    with scheduler:
+        scheduler.start()
+        scheduler.teardown()
+
+    return storage.experiments[exp.experiment_id]
+
+
 def _dummy_run_exp(
     storage: SqlStorage,
     exp: SqlStorage.Experiment,
@@ -197,13 +286,49 @@ def _dummy_run_exp(
     return storage.experiments[exp.experiment_id]


+def _exp_data(
+    storage: SqlStorage,
+    exp_storage: SqlStorage.Experiment,
+    run_exp: Callable[[SqlStorage, SqlStorage.Experiment], ExperimentData] = _dummy_run_exp,
+) -> ExperimentData:
+    """Test fixture for ExperimentData."""
+    return run_exp(storage, exp_storage)
+
+
+def _exp_no_tunables_data(
+    storage: SqlStorage,
+    exp_no_tunables_storage: SqlStorage.Experiment,
+    run_exp: Callable[[SqlStorage, SqlStorage.Experiment], ExperimentData] = _dummy_run_exp,
+) -> ExperimentData:
+    """Test fixture for ExperimentData with no tunable configs."""
+    return run_exp(storage, exp_no_tunables_storage)
+
+
+def _mixed_numerics_exp_data(
+    storage: SqlStorage,
+    mixed_numerics_exp_storage: SqlStorage.Experiment,
+    run_exp: Callable[[SqlStorage, SqlStorage.Experiment], ExperimentData] = _dummy_run_exp,
+) -> ExperimentData:
+    """Test fixture for ExperimentData with mixed numerical tunable types."""
+    return run_exp(storage, mixed_numerics_exp_storage)
+
+
 @pytest.fixture
 def exp_data(
     storage: SqlStorage,
     exp_storage: SqlStorage.Experiment,
 ) -> ExperimentData:
     """Test fixture for ExperimentData."""
-    return _dummy_run_exp(storage, exp_storage)
+    return _exp_data(storage, exp_storage)
+
+
+@pytest.fixture
+def parallel_exp_data(
+    storage: SqlStorage,
+    exp_storage: SqlStorage.Experiment,
+) -> ExperimentData:
+    """Test fixture for ExperimentData with parallel scheduling."""
+    return _exp_data(storage, exp_storage, run_exp=_parallel_dummy_run_exp)


 @pytest.fixture
@@ -212,7 +337,7 @@ def exp_no_tunables_data(
     storage: SqlStorage,
     exp_no_tunables_storage: SqlStorage.Experiment,
 ) -> ExperimentData:
     """Test fixture for ExperimentData with no tunable configs."""
-    return _dummy_run_exp(storage, exp_no_tunables_storage)
+    return _exp_no_tunables_data(storage, exp_no_tunables_storage)


 @pytest.fixture
@@ -221,4 +346,4 @@ def mixed_numerics_exp_data(
     storage: SqlStorage,
     mixed_numerics_exp_storage: SqlStorage.Experiment,
 ) -> ExperimentData:
     """Test fixture for ExperimentData with mixed numerical tunable types."""
-    return _dummy_run_exp(storage, mixed_numerics_exp_storage)
+    return _mixed_numerics_exp_data(storage, mixed_numerics_exp_storage)
diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
index aaf545c787f..804f0ef9d00 100644
--- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py
@@ -3,14 +3,17 @@
 # Licensed under the MIT License.
 #
 """Unit tests for scheduling trials for some future time."""
-from collections.abc import Iterator
+from collections.abc import Callable, Iterator
 from datetime import datetime, timedelta
+from typing import Any

+import numpy as np
 from pytz import UTC

 from mlos_bench.environments.status import Status
 from mlos_bench.storage.base_experiment_data import ExperimentData
 from mlos_bench.storage.base_storage import Storage
+from mlos_bench.storage.base_trial_data import TrialData
 from mlos_bench.tests.storage import (
     CONFIG_COUNT,
     CONFIG_TRIAL_REPEAT_COUNT,
@@ -173,3 +176,29 @@ def test_rr_scheduling(exp_data: ExperimentData) -> None:
         assert (
             trial.trial_runner_id == expected_runner_id
         ), f"Expected trial_runner_id {expected_runner_id} for {trial}"
+
+
+def test_parallel_scheduling(parallel_exp_data: ExperimentData) -> None:
+    """
+    Checks that the scheduler schedules all of the Trials across the TrialRunners.
+
+    Note that we can no longer assume the order of the trials, since they can complete
+    in any order.
+    """
+    extractor: Callable[[Callable[[TrialData], Any]], list[Any]] = lambda fn: [
+        fn(parallel_exp_data.trials[id])
+        for id in range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT + 1)
+    ]
+
+    trial_ids = extractor(lambda trial: trial.trial_id)
+    assert set(trial_ids) == set(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT + 1))
+
+    config_ids = extractor(lambda trial: trial.tunable_config_id)
+    unique_config_ids, config_counts = np.unique(config_ids, return_counts=True)
+    assert len(unique_config_ids) == CONFIG_COUNT
+    assert all(count == CONFIG_TRIAL_REPEAT_COUNT for count in config_counts)
+
+    repeat_nums = extractor(lambda trial: trial.metadata_dict["repeat_i"])
+    unique_repeat_nums, repeat_nums_counts = np.unique(repeat_nums, return_counts=True)
+    assert len(unique_repeat_nums) == CONFIG_TRIAL_REPEAT_COUNT
+    assert all(count == CONFIG_COUNT for count in repeat_nums_counts)