Skip to content

Commit 2fbaaca

Browse files
authored
Bug fix: restore trial_runner_id on reload from Storage (#969)
# Pull Request ## Title Bug fix: restore trial_runner_id on reload from Storage ______________________________________________________________________ ## Description @jsfreischuetz noticed that `pending_trials` was returning `Trial` objects without a `trial_runner_id` even though they were previously set. This change fixes that and closes #968. - Requires explicit `trial_runner_id` on `Trial` object instantiation - Adds it in places where it was missing. ______________________________________________________________________ ## Type of Change - 🛠️ Bug fix - 🧪 Tests ______________________________________________________________________ ## Testing New and existing tests. ______________________________________________________________________ ## Additional Notes (optional) Found in the course of work on #967 ______________________________________________________________________
1 parent e262652 commit 2fbaaca

File tree

4 files changed

+21
-2
lines changed

4 files changed

+21
-2
lines changed

mlos_bench/mlos_bench/storage/base_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def __init__( # pylint: disable=too-many-arguments
409409
experiment_id: str,
410410
trial_id: int,
411411
tunable_config_id: int,
412-
trial_runner_id: int | None = None,
412+
trial_runner_id: int | None,
413413
opt_targets: dict[str, Literal["min", "max"]],
414414
config: dict[str, Any] | None = None,
415415
status: Status = Status.UNKNOWN,

mlos_bench/mlos_bench/storage/sql/experiment.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor
276276
experiment_id=self._experiment_id,
277277
trial_id=trial.trial_id,
278278
config_id=trial.config_id,
279+
trial_runner_id=trial.trial_runner_id,
279280
opt_targets=self._opt_targets,
280281
config=config,
281282
)
@@ -350,6 +351,7 @@ def _new_trial(
350351
experiment_id=self._experiment_id,
351352
trial_id=self._trial_id,
352353
config_id=config_id,
354+
trial_runner_id=None, # initially, Trials are not assigned to a TrialRunner
353355
opt_targets=self._opt_targets,
354356
config=config,
355357
status=new_trial_status,

mlos_bench/mlos_bench/storage/sql/trial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__( # pylint: disable=too-many-arguments
3838
experiment_id: str,
3939
trial_id: int,
4040
config_id: int,
41-
trial_runner_id: int | None = None,
41+
trial_runner_id: int | None,
4242
opt_targets: dict[str, Literal["min", "max"]],
4343
config: dict[str, Any] | None = None,
4444
status: Status = Status.UNKNOWN,

mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,20 @@ def test_schedule_trial(
4747
# Schedule 2 hours in the future:
4848
trial_2h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr * 2, config)
4949

50+
# Check that if we assign a TrialRunner that that value is still available on restore.
51+
trial_now2.set_trial_runner(1)
52+
assert trial_now2.trial_runner_id
53+
5054
exp_data = storage.experiments[exp_storage.experiment_id]
5155
trial_now1_data = exp_data.trials[trial_now1.trial_id]
5256
assert trial_now1_data.trial_runner_id is None
5357
assert trial_now1_data.status == Status.PENDING
5458
# Check that Status matches in object vs. backend storage.
5559
assert trial_now1.status == trial_now1_data.status
5660

61+
trial_now2_data = exp_data.trials[trial_now2.trial_id]
62+
assert trial_now2_data.trial_runner_id == trial_now2.trial_runner_id
63+
5764
# Scheduler side: get trials ready to run at certain timestamps:
5865

5966
# Pretend 1 minute has passed, get trials scheduled to run:
@@ -63,6 +70,16 @@ def test_schedule_trial(
6370
trial_now2.trial_id,
6471
}
6572

73+
# Make sure that the pending trials and trial_runner_ids match.
74+
pending_trial_runner_ids = {
75+
pending_trial.trial_id: pending_trial.trial_runner_id
76+
for pending_trial in exp_storage.pending_trials(timestamp + timedelta_1min, running=False)
77+
}
78+
assert pending_trial_runner_ids == {
79+
trial_now1.trial_id: trial_now1.trial_runner_id,
80+
trial_now2.trial_id: trial_now2.trial_runner_id,
81+
}
82+
6683
# Get trials scheduled to run within the next 1 hour:
6784
pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False))
6885
assert pending_ids == {

0 commit comments

Comments
 (0)