#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Simple parallel trial scheduler and optimization loop implementation stub code."""
import json
import random
from multiprocessing.pool import AsyncResult, Pool
from time import sleep
from typing import Any


| 13 | +class TrialRunner: # pylint: disable=too-few-public-methods |
| 14 | + """Stub TrialRunner.""" |
| 15 | + |
| 16 | + def __init__(self, runner_id: int): |
| 17 | + self.runner_id = runner_id |
| 18 | + |
| 19 | + def run_trial(self, iteration: int, suggestion: int) -> dict[str, int | float]: |
| 20 | + """Stub run_trial.""" |
| 21 | + # In the real system we'd run the Trial on the Environment and whatnot. |
| 22 | + sleep_time = random.uniform(0, 1) + 0.01 |
| 23 | + print( |
| 24 | + ( |
| 25 | + f"Trial {iteration} is running on {self.runner_id} " |
| 26 | + f"with suggestion {suggestion} with sleep time {sleep_time}" |
| 27 | + ), |
| 28 | + flush=True, |
| 29 | + ) |
| 30 | + # Wait a moment to simulate the time it takes to run the trial. |
| 31 | + sleep(sleep_time) |
| 32 | + print(f"Trial {iteration} on {self.runner_id} is done.", flush=True) |
| 33 | + return { |
| 34 | + "runner_id": self.runner_id, |
| 35 | + "iteration": iteration, |
| 36 | + "suggestion": suggestion, |
| 37 | + "sleep_time": sleep_time, |
| 38 | + } |


class ParallelTrialScheduler:
    """Stub ParallelTrialScheduler.

    Round-robins optimizer "suggestions" across a fixed set of TrialRunners,
    runs them in parallel via a multiprocessing Pool, and collects results
    through pool callbacks until ``max_iterations`` trials have completed.
    """

    def __init__(self, num_trial_runners: int, max_iterations: int):
        """Create a scheduler with ``num_trial_runners`` workers that stops
        after ``max_iterations`` trials have been run."""
        self._max_iterations = max_iterations
        self._trial_runners = [TrialRunner(i) for i in range(num_trial_runners)]

        # Track the current status of a TrialRunner.
        # None => idle; an AsyncResult => a trial is in flight on that runner.
        # In a real system we might need to track which TrialRunner is busy in
        # the backend Storage in case of failures of the main process or else
        # just treat their state as idempotent such that we could resume and
        # check on their status at any time.
        # That would also require a deterministic scheduling algorithm so that
        # we restart the same Trial on the same TrialRunner rather than picking
        # a new one.
        self._trial_runners_status: dict[int, AsyncResult | None] = {
            runner.runner_id: None for runner in self._trial_runners
        }

        # Simple trial schedule: maps a trial id to a (runner_id, suggestion)
        # pair; entries are removed by the finished callback once a trial completes.
        # In the real system we'd store everything in the Storage backend.
        self._trial_schedule: dict[int, tuple[int, int]] = {}
        # Next runner in the round-robin rotation used by assign_trial_runner().
        self._current_runner_id = 0

        # Store all the results in a dictionary (trial id -> result dict).
        # In the real system we'd submit them to the Storage and the Optimizer.
        self._results: dict[int, dict[str, int | float]] = {}

    def get_last_trial_id(self) -> int:
        """Very simple method of tracking the last trial id assigned.

        Returns -1 when nothing has been scheduled or finished yet, so the
        first trial id handed out is 0.
        """
        return max(list(self._results.keys()) + list(self._trial_schedule.keys()), default=-1)

    def is_done_scheduling(self) -> bool:
        """Check if the scheduler loop is done."""
        # This is a simple stopping condition to check and see if we've
        # scheduled enough trials.
        return self.get_last_trial_id() + 1 >= self._max_iterations

    def is_done_running(self) -> bool:
        """Check if the scheduler run loop is done."""
        # This is a simple stopping condition to check and see if we've
        # run all the trials.
        return len(self._results) >= self._max_iterations

    def assign_trial_runner(self, trial_id: int, suggestion: int) -> None:
        """Stub assign_trial_runner: record trial_id -> (runner, suggestion)."""
        # In a real system we'd have a more sophisticated way of assigning
        # trials to TrialRunners.
        # Here we just round-robin the suggestions to the available TrialRunners.
        # Note: this only records the assignment; the trial is actually started
        # later by start_optimization_loop().
        next_runner_id = self._current_runner_id
        self._current_runner_id = (self._current_runner_id + 1) % len(self._trial_runners)
        self._trial_schedule[trial_id] = (next_runner_id, suggestion)
        print(
            f"Assigned trial {trial_id} to runner {next_runner_id} with suggestion {suggestion}",
            flush=True,
        )

    def schedule_new_trials(self, num_new_trials: int = 1) -> None:
        """Stub schedule_new_trial(s): assign up to ``num_new_trials`` new
        suggestions, stopping early once the trial budget is reached."""

        # Accept more than one new suggestion at a time to simulate a real
        # system that might be doing multi-objective pareto frontier
        # optimization.

        while num_new_trials > 0 and not self.is_done_scheduling():
            # Generate one (or more) new suggestion(s).
            # In the real system we'd get these from the Optimizer.
            suggestion = random.randint(0, 100)

            # Note: it might be also be the case that we want to repeat that
            # suggestion multiple times on different TrialRunners.

            # Schedule it to a TrialRunner.
            next_trial_id = self.get_last_trial_id() + 1
            self.assign_trial_runner(next_trial_id, suggestion)
            num_new_trials -= 1

    def _run_trial_failed_callback(self, obj: Any) -> None:  # pylint: disable=no-self-use
        """Stub callback to run when run_trial fails in pool process.

        NOTE(review): this raise happens on the Pool's callback-handler
        thread, not in the main loop, so it will not unwind
        start_optimization_loop() directly — confirm this is the intended
        failure behavior.
        """
        raise RuntimeError(f"Trial failed: {obj}")

    def _run_trial_finished_callback(self, result: dict[str, int | float]) -> None:
        """Stub callback to run when run_trial finishes in pool process.

        NOTE(review): runs on the Pool's result-handler thread in the parent
        process, concurrently with start_optimization_loop(); the dict
        mutations below presumably rely on CPython's atomic dict operations —
        confirm.
        """

        # Store the result of the trial.
        trial_id = result["iteration"]
        assert isinstance(trial_id, int)
        self._results[trial_id] = result

        # Remove it from the schedule.
        self._trial_schedule.pop(trial_id)

        # And mark the TrialRunner as available.
        runner_id = result["runner_id"]
        assert isinstance(runner_id, int)
        trial_runner_status = self._trial_runners_status.get(runner_id)
        assert isinstance(trial_runner_status, AsyncResult)
        # assert trial_runner_status.ready()
        self._trial_runners_status[runner_id] = None

        print(f"Trial {trial_id} on {runner_id} callback is done.", flush=True)

        # Schedule more trials.
        # Note: this would schedule additional trials everytime one completes.
        # An alternative option would be to batch them up and schedule several
        # after a few complete.
        # The tradeoffs being model retraining time vs. waiting on straggler
        # workers vs. optimizer new suggestion accuracy.
        # Moreover, we need to handle the edge case and include scheduling in
        # the loop anyways, so it's probably better to just leave it all there.
        # self.schedule_new_trials(num_new_trials=1)

    def get_idle_trial_runners_count(self) -> int:
        """Count TrialRunners with no in-flight AsyncResult (status None)."""
        return len([x for x in self._trial_runners_status.values() if x is None])

    def start_optimization_loop(self) -> None:
        """Stub start_optimization_loop.

        Main driver: repeatedly start assigned trials on idle runners, poll
        until a runner frees up, and schedule new suggestions, until
        ``max_iterations`` trials have been scheduled and run.
        """

        # Create a pool of processes to run the trials in parallel.
        with Pool(processes=len(self._trial_runners), maxtasksperchild=1) as pool:
            while not self.is_done_scheduling() or not self.is_done_running():
                # Run any existing trials that aren't currently running.
                # Do this first in case we're resuming from a previous run
                # (e.g., the real system will have remembered which Trials were
                # in progress by reloading them from the Storage backend).

                # Avoid modifying the dictionary while iterating over it
                # (the finished callback pops entries from another thread).
                trial_schedule = self._trial_schedule.copy()
                for trial_id, (runner_id, suggestion) in trial_schedule.items():
                    # Skip trials that are already running on their assigned TrialRunner.
                    if self._trial_runners_status[runner_id] is not None:
                        continue
                    # Else, start the Trial on the given TrialRunner in the background.
                    self._trial_runners_status[runner_id] = pool.apply_async(
                        TrialRunner(runner_id).run_trial,
                        args=(trial_id, suggestion),
                        callback=self._run_trial_finished_callback,
                        error_callback=self._run_trial_failed_callback,
                    )
                # Now all the available TrialRunners that had work to do should be running.

                # Wait a moment to check if we have any idle TrialRunners.
                # This also allows us a chance to collect multiple results from
                # the pool before suggesting new ones.
                while len(self._trial_schedule) > 0 and self.get_idle_trial_runners_count() == 0:
                    # Make the polling interval here configurable.
                    sleep(0.5)

                # Schedule more trials if we can (at least one, so the outer
                # loop always makes progress toward is_done_scheduling()).
                self.schedule_new_trials(num_new_trials=self.get_idle_trial_runners_count() or 1)

            # Should be all done starting new trials.
            print("Closing the pool.", flush=True)
            pool.close()

            print("Waiting for all trials to finish.", flush=True)
            # FIXME: This sometimes hangs. Not sure why yet.
            # NOTE(review): callbacks fire on the pool's handler thread while
            # this thread polls/joins; looks like a shutdown race — investigate.
            pool.join()

        print("Optimization loop is done.", flush=True)
        print("results: " + json.dumps(self._results, indent=2))
        print("trial_schedule: " + json.dumps(self._trial_schedule, indent=2))
        # Only JSON-serializable because every status should be None by now
        # (AsyncResult objects are not serializable by json.dumps).
        print("trial_runner_status: " + json.dumps(self._trial_runners_status, indent=2))
        assert len(self._results) == self._max_iterations, "Unexpected number of trials run."
        assert not self._trial_schedule, "Some scheduled trials were not started."
        assert all(
            x is None for x in self._trial_runners_status.values()
        ), "Some TrialRunners are still running."


def main():
    """Demo entry point: run the stub scheduler on a small configuration."""
    print("Starting ParallelTrialScheduler.", flush=True)
    ParallelTrialScheduler(num_trial_runners=4, max_iterations=15).start_optimization_loop()


# Allow running this module directly as a demo script.
if __name__ == "__main__":
    main()