Skip to content

Commit 713bcc1

Browse files
committed
Adding an example forking parallel trial scheduler.
1 parent 1dcfca2 commit 713bcc1

File tree

1 file changed

+220
-0
lines changed

1 file changed

+220
-0
lines changed
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
#
5+
"""Simple parallel trial scheduler and optimization loop implementation stub code."""
6+
import json
7+
import random
8+
from multiprocessing.pool import AsyncResult, Pool
9+
from time import sleep
10+
from typing import Any
11+
12+
13+
class TrialRunner: # pylint: disable=too-few-public-methods
14+
"""Stub TrialRunner."""
15+
16+
def __init__(self, runner_id: int):
17+
self.runner_id = runner_id
18+
19+
def run_trial(self, iteration: int, suggestion: int) -> dict[str, int | float]:
20+
"""Stub run_trial."""
21+
# In the real system we'd run the Trial on the Environment and whatnot.
22+
sleep_time = random.uniform(0, 1) + 0.01
23+
print(
24+
(
25+
f"Trial {iteration} is running on {self.runner_id} "
26+
f"with suggestion {suggestion} with sleep time {sleep_time}"
27+
),
28+
flush=True,
29+
)
30+
# Wait a moment to simulate the time it takes to run the trial.
31+
sleep(sleep_time)
32+
print(f"Trial {iteration} on {self.runner_id} is done.", flush=True)
33+
return {
34+
"runner_id": self.runner_id,
35+
"iteration": iteration,
36+
"suggestion": suggestion,
37+
"sleep_time": sleep_time,
38+
}
39+
40+
41+
class ParallelTrialScheduler:
42+
"""Stub ParallelTrialScheduler."""
43+
44+
def __init__(self, num_trial_runners: int, max_iterations: int):
45+
self._max_iterations = max_iterations
46+
self._trial_runners = [TrialRunner(i) for i in range(num_trial_runners)]
47+
48+
# Track the current status of a TrialRunner.
49+
# In a real system we might need to track which TrialRunner is busy in
50+
# the backend Storage in case of failures of the main process or else
51+
# just treat their state as idempotent such that we could resume and
52+
# check on their status at any time.
53+
# That would also require a deterministic scheduling algorithm so that
54+
# we restart the same Trial on the same TrialRunner rather than picking
55+
# a new one.
56+
self._trial_runners_status: dict[int, AsyncResult | None] = {
57+
runner.runner_id: None for runner in self._trial_runners
58+
}
59+
60+
# Simple trial schedule: maps a trial id to a TrialRunner.
61+
# In the real system we'd store everything in the Storage backend.
62+
self._trial_schedule: dict[int, tuple[int, int]] = {}
63+
self._current_runner_id = 0
64+
65+
# Store all the results in a dictionary.
66+
# In the real system we'd submit them to the Storage and the Optimizer.
67+
self._results: dict[int, dict[str, int | float]] = {}
68+
69+
def get_last_trial_id(self) -> int:
70+
"""Very simple method of tracking the last trial id assigned."""
71+
return max(list(self._results.keys()) + list(self._trial_schedule.keys()), default=-1)
72+
73+
def is_done_scheduling(self) -> bool:
74+
"""Check if the scheduler loop is done."""
75+
# This is a simple stopping condition to check and see if we've
76+
# scheduled enough trials.
77+
return self.get_last_trial_id() + 1 >= self._max_iterations
78+
79+
def is_done_running(self) -> bool:
80+
"""Check if the scheduler run loop is done."""
81+
# This is a simple stopping condition to check and see if we've
82+
# run all the trials.
83+
return len(self._results) >= self._max_iterations
84+
85+
def assign_trial_runner(self, trial_id: int, suggestion: int) -> None:
86+
"""Stub assign_trial_runner."""
87+
# In a real system we'd have a more sophisticated way of assigning
88+
# trials to TrialRunners.
89+
# Here we just round-robin the suggestions to the available TrialRunners.
90+
next_runner_id = self._current_runner_id
91+
self._current_runner_id = (self._current_runner_id + 1) % len(self._trial_runners)
92+
self._trial_schedule[trial_id] = (next_runner_id, suggestion)
93+
print(
94+
f"Assigned trial {trial_id} to runner {next_runner_id} with suggestion {suggestion}",
95+
flush=True,
96+
)
97+
98+
def schedule_new_trials(self, num_new_trials: int = 1) -> None:
99+
"""Stub schedule_new_trial(s)."""
100+
101+
# Accept more than one new suggestion at a time to simulate a real
102+
# system that might be doing multi-objective pareto frontier
103+
# optimization.
104+
105+
while num_new_trials > 0 and not self.is_done_scheduling():
106+
# Generate one (or more) new suggestion(s).
107+
# In the real system we'd get these from the Optimizer.
108+
suggestion = random.randint(0, 100)
109+
110+
# Note: it might be also be the case that we want to repeat that
111+
# suggestion multiple times on different TrialRunners.
112+
113+
# Schedule it to a TrialRunner.
114+
next_trial_id = self.get_last_trial_id() + 1
115+
self.assign_trial_runner(next_trial_id, suggestion)
116+
num_new_trials -= 1
117+
118+
def _run_trial_failed_callback(self, obj: Any) -> None: # pylint: disable=no-self-use
119+
"""Stub callback to run when run_trial fails in pool process."""
120+
raise RuntimeError(f"Trial failed: {obj}")
121+
122+
def _run_trial_finished_callback(self, result: dict[str, int | float]) -> None:
123+
"""Stub callback to run when run_trial finishes in pool process."""
124+
125+
# Store the result of the trial.
126+
trial_id = result["iteration"]
127+
assert isinstance(trial_id, int)
128+
self._results[trial_id] = result
129+
130+
# Remove it from the schedule.
131+
self._trial_schedule.pop(trial_id)
132+
133+
# And mark the TrialRunner as available.
134+
runner_id = result["runner_id"]
135+
assert isinstance(runner_id, int)
136+
trial_runner_status = self._trial_runners_status.get(runner_id)
137+
assert isinstance(trial_runner_status, AsyncResult)
138+
# assert trial_runner_status.ready()
139+
self._trial_runners_status[runner_id] = None
140+
141+
print(f"Trial {trial_id} on {runner_id} callback is done.", flush=True)
142+
143+
# Schedule more trials.
144+
# Note: this would schedule additional trials everytime one completes.
145+
# An alternative option would be to batch them up and schedule several
146+
# after a few complete.
147+
# The tradeoffs being model retraining time vs. waiting on straggler
148+
# workers vs. optimizer new suggestion accuracy.
149+
# Moreover, we need to handle the edge case and include scheduling in
150+
# the loop anyways, so it's probably better to just leave it all there.
151+
# self.schedule_new_trials(num_new_trials=1)
152+
153+
def get_idle_trial_runners_count(self) -> int:
154+
"""Stub get_idle_trial_runners_count."""
155+
return len([x for x in self._trial_runners_status.values() if x is None])
156+
157+
def start_optimization_loop(self) -> None:
158+
"""Stub start_optimization_loop."""
159+
160+
# Create a pool of processes to run the trials in parallel.
161+
with Pool(processes=len(self._trial_runners), maxtasksperchild=1) as pool:
162+
while not self.is_done_scheduling() or not self.is_done_running():
163+
# Run any existing trials that aren't currently running.
164+
# Do this first in case we're resuming from a previous run
165+
# (e.g., the real system will have remembered which Trials were
166+
# in progress by reloading them from the Storage backend).
167+
168+
# Avoid modifying the dictionary while iterating over it.
169+
trial_schedule = self._trial_schedule.copy()
170+
for trial_id, (runner_id, suggestion) in trial_schedule.items():
171+
# Skip trials that are already running on their assigned TrialRunner.
172+
if self._trial_runners_status[runner_id] is not None:
173+
continue
174+
# Else, start the Trial on the given TrialRunner in the background.
175+
self._trial_runners_status[runner_id] = pool.apply_async(
176+
TrialRunner(runner_id).run_trial,
177+
args=(trial_id, suggestion),
178+
callback=self._run_trial_finished_callback,
179+
error_callback=self._run_trial_failed_callback,
180+
)
181+
# Now all the available TrialRunners that had work to do should be running.
182+
183+
# Wait a moment to check if we have any idle TrialRunners.
184+
# This also allows us a chance to collect multiple results from
185+
# the pool before suggesting new ones.
186+
while len(self._trial_schedule) > 0 and self.get_idle_trial_runners_count() == 0:
187+
# Make the polling interval here configurable.
188+
sleep(0.5)
189+
190+
# Schedule more trials if we can.
191+
self.schedule_new_trials(num_new_trials=self.get_idle_trial_runners_count() or 1)
192+
193+
# Should be all done starting new trials.
194+
print("Closing the pool.", flush=True)
195+
pool.close()
196+
197+
print("Waiting for all trials to finish.", flush=True)
198+
# FIXME: This sometimes hangs. Not sure why yet.
199+
pool.join()
200+
201+
print("Optimization loop is done.", flush=True)
202+
print("results: " + json.dumps(self._results, indent=2))
203+
print("trial_schedule: " + json.dumps(self._trial_schedule, indent=2))
204+
print("trial_runner_status: " + json.dumps(self._trial_runners_status, indent=2))
205+
assert len(self._results) == self._max_iterations, "Unexpected number of trials run."
206+
assert not self._trial_schedule, "Some scheduled trials were not started."
207+
assert all(
208+
x is None for x in self._trial_runners_status.values()
209+
), "Some TrialRunners are still running."
210+
211+
212+
def main():
    """Main function: run a small example parallel optimization loop."""
    print("Starting ParallelTrialScheduler.", flush=True)
    trial_scheduler = ParallelTrialScheduler(num_trial_runners=4, max_iterations=15)
    trial_scheduler.start_optimization_loop()


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)