diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 56c7e98..b89a475 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,7 +10,7 @@ ## New Features - +* We now provide the `DispatchManagingActor` class, a class to manage actors based on incoming dispatches. ## Bug Fixes diff --git a/pyproject.toml b/pyproject.toml index 017dc26..d4f0d65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,9 +39,9 @@ dependencies = [ # Make sure to update the version for cross-referencing also in the # mkdocs.yml file when changing the version here (look for the config key # plugins.mkdocstrings.handlers.python.import) - "frequenz-sdk >= 1.0.0-rc900, < 1.0.0-rc1000", - "frequenz-channels >= 1.1.0, < 2.0.0", - "frequenz-client-dispatch >= 0.6.0, < 0.7.0", + "frequenz-sdk == 1.0.0-rc900, < 1.0.0-rc1000", + "frequenz-channels >= 1.2.0, < 2.0.0", + "frequenz-client-dispatch >= 0.7.0, < 0.8.0", ] dynamic = ["version"] @@ -165,6 +165,7 @@ disable = [ [tool.pytest.ini_options] testpaths = ["tests", "src"] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" required_plugins = ["pytest-asyncio", "pytest-mock"] [tool.mypy] diff --git a/src/frequenz/dispatch/__init__.py b/src/frequenz/dispatch/__init__.py index 16df1b5..037665c 100644 --- a/src/frequenz/dispatch/__init__.py +++ b/src/frequenz/dispatch/__init__.py @@ -7,6 +7,8 @@ * [Dispatcher][frequenz.dispatch.Dispatcher]: The entry point for the API. * [Dispatch][frequenz.dispatch.Dispatch]: A dispatch type with lots of useful extra functionality. +* [DispatchManagingActor][frequenz.dispatch.DispatchManagingActor]: An actor to + manage other actors based on incoming dispatches. * [Created][frequenz.dispatch.Created], [Updated][frequenz.dispatch.Updated], [Deleted][frequenz.dispatch.Deleted]: Dispatch event types. @@ -16,6 +18,7 @@ from ._dispatch import Dispatch, RunningState from ._dispatcher import Dispatcher, ReceiverFetcher from ._event import Created, Deleted, DispatchEvent, Updated +from ._managing_actor import DispatchManagingActor, DispatchUpdate __all__ = [ "Created", @@ -26,4 +29,6 @@ "Updated", "Dispatch", "RunningState", + "DispatchManagingActor", + "DispatchUpdate", ] diff --git a/src/frequenz/dispatch/_dispatch.py b/src/frequenz/dispatch/_dispatch.py index cd67bd1..fab20c8 100644 --- a/src/frequenz/dispatch/_dispatch.py +++ b/src/frequenz/dispatch/_dispatch.py @@ -118,6 +118,13 @@ def running(self, type_: str) -> RunningState: return RunningState.STOPPED now = datetime.now(tz=timezone.utc) + + if now < self.start_time: + return RunningState.STOPPED + # A dispatch without duration is always running once it started + if self.duration is None: + return RunningState.RUNNING + if until := self._until(now): return RunningState.RUNNING if now < until else RunningState.STOPPED @@ -185,6 +192,7 @@ def next_run_after(self, after: datetime) -> datetime | None: if ( not self.recurrence.frequency or self.recurrence.frequency == Frequency.UNSPECIFIED + or self.duration is None # Infinite duration ): if after > self.start_time: return None @@ -236,7 +244,13 @@ def _until(self, now: datetime) -> datetime | None: Returns: The time when the dispatch should end or None if the dispatch is not running. + + Raises: + ValueError: If the dispatch has no duration. """ + if self.duration is None: + raise ValueError("_until: Dispatch has no duration") + if ( not self.recurrence.frequency or self.recurrence.frequency == Frequency.UNSPECIFIED diff --git a/src/frequenz/dispatch/_managing_actor.py b/src/frequenz/dispatch/_managing_actor.py new file mode 100644 index 0000000..e9a631b --- /dev/null +++ b/src/frequenz/dispatch/_managing_actor.py @@ -0,0 +1,180 @@ +# License: All rights reserved +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""Helper class to manage actors based on dispatches.""" + +import logging +from dataclasses import dataclass +from typing import Any, Set + +from frequenz.channels import Receiver, Sender +from frequenz.client.dispatch.types import ComponentSelector +from frequenz.sdk.actor import Actor + +from ._dispatch import Dispatch, RunningState + +_logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, kw_only=True) +class DispatchUpdate: + """Event emitted when the dispatch changes.""" + + components: ComponentSelector + """Components to be used.""" + + dry_run: bool + """Whether this is a dry run.""" + + options: dict[str, Any] + """Additional options.""" + + +class DispatchManagingActor(Actor): + """Helper class to manage actors based on dispatches. + + Example usage: + + ```python + import os + import asyncio + from frequenz.dispatch import Dispatcher, DispatchManagingActor, DispatchUpdate + from frequenz.client.dispatch.types import ComponentSelector + from frequenz.client.common.microgrid.components import ComponentCategory + + from frequenz.channels import Receiver, Broadcast + + class MyActor(Actor): + def __init__(self, updates_channel: Receiver[DispatchUpdate]): + super().__init__() + self._updates_channel = updates_channel + self._dry_run: bool + self._options : dict[str, Any] + + async def _run(self) -> None: + while True: + update = await self._updates_channel.receive() + print("Received update:", update) + + self.set_components(update.components) + self._dry_run = update.dry_run + self._options = update.options + + def set_components(self, components: ComponentSelector) -> None: + match components: + case [int(), *_] as component_ids: + print("Dispatch: Setting components to %s", components) + case [ComponentCategory.BATTERY, *_]: + print("Dispatch: Using all battery components") + case unsupported: + print( + "Dispatch: Requested an unsupported selector %r, " + "but only component IDs or category BATTERY are supported.", + unsupported, + ) + + async def run(): + url = os.getenv("DISPATCH_API_URL", "grpc://fz-0004.frequenz.io:50051") + key = os.getenv("DISPATCH_API_KEY", "some-key") + + microgrid_id = 1 + + dispatcher = Dispatcher( + microgrid_id=microgrid_id, + server_url=url, + key=key + ) + + # Create update channel to receive dispatch update events pre-start and mid-run + dispatch_updates_channel = Broadcast[DispatchUpdate](name="dispatch_updates_channel") + + # Start actor and give it an dispatch updates channel receiver + my_actor = MyActor(dispatch_updates_channel.new_receiver()) + + status_receiver = dispatcher.running_status_change.new_receiver() + + managing_actor = DispatchManagingActor( + actor=my_actor, + dispatch_type="EXAMPLE", + running_status_receiver=status_receiver, + updates_sender=dispatch_updates_channel.new_sender(), + ) + + await asyncio.gather(dispatcher.start(), managing_actor.start()) + ``` + """ + + def __init__( + self, + actor: Actor | Set[Actor], + dispatch_type: str, + running_status_receiver: Receiver[Dispatch], + updates_sender: Sender[DispatchUpdate] | None = None, + ) -> None: + """Initialize the dispatch handler. + + Args: + actor: A set of actors or a single actor to manage. + dispatch_type: The type of dispatches to handle. + running_status_receiver: The receiver for dispatch running status changes. + updates_sender: The sender for dispatch events + """ + super().__init__() + self._dispatch_rx = running_status_receiver + self._actors = frozenset([actor] if isinstance(actor, Actor) else actor) + self._dispatch_type = dispatch_type + self._updates_sender = updates_sender + + def _start_actors(self) -> None: + """Start all actors.""" + for actor in self._actors: + if actor.is_running: + _logger.warning("Actor %s is already running", actor.name) + else: + actor.start() + + async def _stop_actors(self, msg: str) -> None: + """Stop all actors. + + Args: + msg: The message to be passed to the actors being stopped. + """ + for actor in self._actors: + if actor.is_running: + await actor.stop(msg) + else: + _logger.warning("Actor %s is not running", actor.name) + + async def _run(self) -> None: + """Wait for dispatches and handle them.""" + async for dispatch in self._dispatch_rx: + await self._handle_dispatch(dispatch=dispatch) + + async def _handle_dispatch(self, dispatch: Dispatch) -> None: + """Handle a dispatch. + + Args: + dispatch: The dispatch to handle. + """ + running = dispatch.running(self._dispatch_type) + match running: + case RunningState.STOPPED: + _logger.info("Stopped by dispatch %s", dispatch.id) + await self._stop_actors("Dispatch stopped") + case RunningState.RUNNING: + if self._updates_sender is not None: + _logger.info("Updated by dispatch %s", dispatch.id) + await self._updates_sender.send( + DispatchUpdate( + components=dispatch.selector, + dry_run=dispatch.dry_run, + options=dispatch.payload, + ) + ) + + _logger.info("Started by dispatch %s", dispatch.id) + self._start_actors() + case RunningState.DIFFERENT_TYPE: + _logger.debug( + "Unknown dispatch! Ignoring dispatch of type %s", dispatch.type + ) diff --git a/src/frequenz/dispatch/actor.py b/src/frequenz/dispatch/actor.py index 78b674e..9befa08 100644 --- a/src/frequenz/dispatch/actor.py +++ b/src/frequenz/dispatch/actor.py @@ -3,33 +3,20 @@ """The dispatch actor.""" -import asyncio import logging from datetime import datetime, timedelta, timezone +from heapq import heappop, heappush import grpc.aio -from frequenz.channels import Sender -from frequenz.channels.timer import SkipMissedAndDrift, Timer +from frequenz.channels import Sender, select, selected_from +from frequenz.channels.timer import SkipMissedAndResync, Timer from frequenz.client.dispatch import Client +from frequenz.client.dispatch.types import Event from frequenz.sdk.actor import Actor from ._dispatch import Dispatch, RunningState from ._event import Created, Deleted, DispatchEvent, Updated -_MAX_AHEAD_SCHEDULE = timedelta(hours=5) -"""The maximum time ahead to schedule a dispatch. - -We don't want to schedule dispatches too far ahead, -as they could start drifting if the delay is too long. - -This also prevents us from scheduling too many dispatches at once. - -The exact value is not important, but should be a few hours and not more than a day. -""" - -_DEFAULT_POLL_INTERVAL = timedelta(seconds=10) -"""The default interval to poll the API for dispatch changes.""" - _logger = logging.getLogger(__name__) """The logger for this module.""" @@ -50,7 +37,6 @@ def __init__( client: Client, lifecycle_updates_sender: Sender[DispatchEvent], running_state_change_sender: Sender[Dispatch], - poll_interval: timedelta = _DEFAULT_POLL_INTERVAL, ) -> None: """Initialize the actor. @@ -59,31 +45,101 @@ def __init__( client: The client to use for fetching dispatches. lifecycle_updates_sender: A sender for dispatch lifecycle events. running_state_change_sender: A sender for dispatch running state changes. - poll_interval: The interval to poll the API for dispatche changes. """ super().__init__(name="dispatch") self._client = client self._dispatches: dict[int, Dispatch] = {} - self._scheduled: dict[int, asyncio.Task[None]] = {} self._microgrid_id = microgrid_id self._lifecycle_updates_sender = lifecycle_updates_sender self._running_state_change_sender = running_state_change_sender - self._poll_timer = Timer(poll_interval, SkipMissedAndDrift()) + self._next_event_timer = Timer( + timedelta(seconds=100), SkipMissedAndResync(), auto_start=False + ) + """The timer to schedule the next event. + + Interval is chosen arbitrarily, as it will be reset on the first event. + """ + + self._scheduled_events: list[tuple[datetime, Dispatch]] = [] + """The scheduled events, sorted by time. + + Each event is a tuple of the scheduled time and the dispatch. + heapq is used to keep the list sorted by time, so the next event is + always at index 0. + """ async def _run(self) -> None: """Run the actor.""" - self._poll_timer.reset() - try: - async for _ in self._poll_timer: - await self._fetch() - except asyncio.CancelledError: - for task in self._scheduled.values(): - task.cancel() - raise + _logger.info("Starting dispatch actor for microgrid %s", self._microgrid_id) + + # Initial fetch + await self._fetch() + + stream = self._client.stream(microgrid_id=self._microgrid_id) + + # Streaming updates + async for selected in select(self._next_event_timer, stream): + if selected_from(selected, self._next_event_timer): + if not self._scheduled_events: + continue + _logger.debug( + "Executing scheduled event: %s", self._scheduled_events[0][1] + ) + await self._execute_scheduled_event(heappop(self._scheduled_events)[1]) + elif selected_from(selected, stream): + _logger.debug("Received dispatch event: %s", selected.message) + dispatch = Dispatch(selected.message.dispatch) + match selected.message.event: + case Event.CREATED: + self._dispatches[dispatch.id] = dispatch + await self._update_dispatch_schedule_and_notify(dispatch, None) + await self._lifecycle_updates_sender.send( + Created(dispatch=dispatch) + ) + case Event.UPDATED: + await self._update_dispatch_schedule_and_notify( + dispatch, self._dispatches[dispatch.id] + ) + self._dispatches[dispatch.id] = dispatch + await self._lifecycle_updates_sender.send( + Updated(dispatch=dispatch) + ) + case Event.DELETED: + self._dispatches.pop(dispatch.id) + await self._update_dispatch_schedule_and_notify(None, dispatch) + + dispatch._set_deleted() # pylint: disable=protected-access + await self._lifecycle_updates_sender.send( + Deleted(dispatch=dispatch) + ) + + async def _execute_scheduled_event(self, dispatch: Dispatch) -> None: + """Execute a scheduled event. + + Args: + dispatch: The dispatch to execute. + """ + await self._send_running_state_change(dispatch) + + # The timer is always a tiny bit delayed, so we need to check if the + # actor is supposed to be running now (we're assuming it wasn't already + # running, as all checks are done before scheduling) + if dispatch.running(dispatch.type) == RunningState.RUNNING: + # If it should be running, schedule the stop event + self._schedule_stop(dispatch) + # If the actor is not running, we need to schedule the next start + else: + self._schedule_start(dispatch) + + self._update_timer() async def _fetch(self) -> None: - """Fetch all relevant dispatches.""" + """Fetch all relevant dispatches using list. + + This is used for the initial fetch and for re-fetching all dispatches + if the connection was lost. + """ old_dispatches = self._dispatches self._dispatches = {} @@ -96,21 +152,20 @@ async def _fetch(self) -> None: self._dispatches[dispatch.id] = Dispatch(client_dispatch) old_dispatch = old_dispatches.pop(dispatch.id, None) if not old_dispatch: - self._update_dispatch_schedule(dispatch, None) _logger.info("New dispatch: %s", dispatch) + await self._update_dispatch_schedule_and_notify(dispatch, None) await self._lifecycle_updates_sender.send( Created(dispatch=dispatch) ) elif dispatch.update_time != old_dispatch.update_time: - self._update_dispatch_schedule(dispatch, old_dispatch) _logger.info("Updated dispatch: %s", dispatch) + await self._update_dispatch_schedule_and_notify( + dispatch, old_dispatch + ) await self._lifecycle_updates_sender.send( Updated(dispatch=dispatch) ) - if self._running_state_change(dispatch, old_dispatch): - await self._send_running_state_change(dispatch) - except grpc.aio.AioRpcError as error: _logger.error("Error fetching dispatches: %s", error) self._dispatches = old_dispatches @@ -118,21 +173,23 @@ async def _fetch(self) -> None: for dispatch in old_dispatches.values(): _logger.info("Deleted dispatch: %s", dispatch) - dispatch._set_deleted() # pylint: disable=protected-access await self._lifecycle_updates_sender.send(Deleted(dispatch=dispatch)) - if task := self._scheduled.pop(dispatch.id, None): - task.cancel() + await self._update_dispatch_schedule_and_notify(None, dispatch) - if self._running_state_change(None, dispatch): - await self._send_running_state_change(dispatch) + # Set deleted only here as it influences the result of dispatch.running() + # which is used in above in _running_state_change + dispatch._set_deleted() # pylint: disable=protected-access + await self._lifecycle_updates_sender.send(Deleted(dispatch=dispatch)) - def _update_dispatch_schedule( - self, dispatch: Dispatch, old_dispatch: Dispatch | None + async def _update_dispatch_schedule_and_notify( + self, dispatch: Dispatch | None, old_dispatch: Dispatch | None ) -> None: """Update the schedule for a dispatch. - Schedules, reschedules or cancels the dispatch based on the start_time - and active status. + Schedules, reschedules or cancels the dispatch events + based on the start_time and active status. + + Sends a running state change notification if necessary. For example: * when the start_time changes, the dispatch is rescheduled @@ -142,65 +199,107 @@ def _update_dispatch_schedule( dispatch: The dispatch to update the schedule for. old_dispatch: The old dispatch, if available. """ - if ( - old_dispatch - and old_dispatch.active - and old_dispatch.start_time != dispatch.start_time - ): - if task := self._scheduled.pop(dispatch.id, None): - task.cancel() + # If dispatch is None, the dispatch was deleted + # and we need to cancel any existing event for it + if not dispatch and old_dispatch: + self._remove_scheduled(old_dispatch) + + # If the dispatch was running, we need to notify + if old_dispatch.running(old_dispatch.type) == RunningState.RUNNING: + await self._send_running_state_change(old_dispatch) + + # A new dispatch was created + elif dispatch and not old_dispatch: + assert not self._remove_scheduled( + dispatch + ), "New dispatch already scheduled?!" + + # If its currently running, send notification right away + if dispatch.running(dispatch.type) == RunningState.RUNNING: + await self._send_running_state_change(dispatch) - if dispatch.active and dispatch.id not in self._scheduled: - self._scheduled[dispatch.id] = asyncio.create_task( - self._schedule_task(dispatch) - ) + self._schedule_stop(dispatch) + # Otherwise, if it's enabled but not yet running, schedule it + else: + self._schedule_start(dispatch) - async def _schedule_task(self, dispatch: Dispatch) -> None: - """Wait for a dispatch to become ready. + # Dispatch was updated + elif dispatch and old_dispatch: + # Remove potentially existing scheduled event + self._remove_scheduled(old_dispatch) - Waits for the dispatches next run and then notifies that it is ready. + # Check if the change requires an immediate notification + if self._update_changed_running_state(dispatch, old_dispatch): + await self._send_running_state_change(dispatch) - Args: - dispatch: The dispatch to schedule. - """ + if dispatch.running(dispatch.type) == RunningState.RUNNING: + self._schedule_stop(dispatch) + else: + self._schedule_start(dispatch) + + # We modified the schedule, so we need to reset the timer + self._update_timer() - def next_run_info() -> tuple[datetime, datetime] | None: - now = datetime.now(tz=timezone.utc) - next_run = dispatch.next_run_after(now) + def _update_timer(self) -> None: + """Update the timer to the next event.""" + if self._scheduled_events: + due_at: datetime = self._scheduled_events[0][0] + self._next_event_timer.reset(interval=due_at - datetime.now(timezone.utc)) + _logger.debug("Next event scheduled at %s", self._scheduled_events[0][0]) - if next_run is None: - return None + def _remove_scheduled(self, dispatch: Dispatch) -> bool: + """Remove a dispatch from the scheduled events. - return now, next_run + Args: + dispatch: The dispatch to remove. - while pair := next_run_info(): - now, next_time = pair + Returns: + True if the dispatch was found and removed, False otherwise. + """ + for idx, (_, sched_dispatch) in enumerate(self._scheduled_events): + if dispatch.id == sched_dispatch.id: + self._scheduled_events.pop(idx) + return True - if next_time - now > _MAX_AHEAD_SCHEDULE: - await asyncio.sleep(_MAX_AHEAD_SCHEDULE.total_seconds()) - continue + return False - _logger.info("Dispatch %s scheduled for %s", dispatch.id, next_time) - await asyncio.sleep((next_time - now).total_seconds()) + def _schedule_start(self, dispatch: Dispatch) -> None: + """Schedule a dispatch to start. - _logger.info("Dispatch %s executing...", dispatch) - await self._running_state_change_sender.send(dispatch) + Args: + dispatch: The dispatch to schedule. + """ + # If the dispatch is not active, don't schedule it + if not dispatch.active: + return - # Wait for the duration of the dispatch if set - if dispatch.duration: - _logger.info( - "Dispatch %s running for %s", dispatch.id, dispatch.duration + # Schedule the next run + try: + if next_run := dispatch.next_run: + heappush(self._scheduled_events, (next_run, dispatch)) + _logger.debug( + "Scheduled dispatch %s to start at %s", dispatch.id, next_run ) - await asyncio.sleep(dispatch.duration.total_seconds()) + else: + _logger.debug("Dispatch %s has no next run", dispatch.id) + except ValueError as error: + _logger.error("Error scheduling dispatch %s: %s", dispatch.id, error) - _logger.info("Dispatch %s runtime duration reached", dispatch.id) - await self._running_state_change_sender.send(dispatch) + def _schedule_stop(self, dispatch: Dispatch) -> None: + """Schedule a dispatch to stop. - _logger.info("Dispatch completed: %s", dispatch) - self._scheduled.pop(dispatch.id) - - def _running_state_change( - self, updated_dispatch: Dispatch | None, previous_dispatch: Dispatch | None + Args: + dispatch: The dispatch to schedule. + """ + # Setup stop timer if the dispatch has a duration + if dispatch.duration and dispatch.duration > timedelta(seconds=0): + until = dispatch.until + assert until is not None + heappush(self._scheduled_events, (until, dispatch)) + _logger.debug("Scheduled dispatch %s to stop at %s", dispatch, until) + + def _update_changed_running_state( + self, updated_dispatch: Dispatch, previous_dispatch: Dispatch ) -> bool: """Check if the running state of a dispatch has changed. @@ -212,29 +311,12 @@ def _running_state_change( in which case we need to send the message now. Args: - updated_dispatch: The new dispatch, if available. - previous_dispatch: The old dispatch, if available. + updated_dispatch: The new dispatch + previous_dispatch: The old dispatch Returns: True if the running state has changed, False otherwise. """ - # New dispatch - if previous_dispatch is None: - assert updated_dispatch is not None - - # Client was not informed about the dispatch, do it now - # pylint: disable=protected-access - if not updated_dispatch._running_status_notified: - return True - - # Deleted dispatch - if updated_dispatch is None: - assert previous_dispatch is not None - return ( - previous_dispatch.running(previous_dispatch.type) - == RunningState.RUNNING - ) - # If any of the runtime attributes changed, we need to send a message runtime_state_attributes = [ "running", diff --git a/tests/test_frequenz_dispatch.py b/tests/test_frequenz_dispatch.py index 45ba78c..303974e 100644 --- a/tests/test_frequenz_dispatch.py +++ b/tests/test_frequenz_dispatch.py @@ -15,18 +15,26 @@ from frequenz.client.dispatch.test.client import FakeClient, to_create_params from frequenz.client.dispatch.test.generator import DispatchGenerator from frequenz.client.dispatch.types import Dispatch as BaseDispatch -from frequenz.client.dispatch.types import Frequency +from frequenz.client.dispatch.types import Frequency, RecurrenceRule from pytest import fixture -from frequenz.dispatch import Created, Deleted, Dispatch, DispatchEvent, Updated +from frequenz.dispatch import ( + Created, + Deleted, + Dispatch, + DispatchEvent, + RunningState, + Updated, +) from frequenz.dispatch.actor import DispatchingActor -# This method replaces the event loop for all tests in the file. @fixture def event_loop_policy() -> async_solipsism.EventLoopPolicy: - """Return an event loop policy that uses the async solipsism event loop.""" - return async_solipsism.EventLoopPolicy() + """Set the event loop policy to use async_solipsism.""" + policy = async_solipsism.EventLoopPolicy() + asyncio.set_event_loop_policy(policy) + return policy @fixture @@ -51,7 +59,7 @@ class ActorTestEnv: """The actor under test.""" updated_dispatches: Receiver[DispatchEvent] """The receiver for updated dispatches.""" - ready_dispatches: Receiver[Dispatch] + running_state_change: Receiver[Dispatch] """The receiver for ready dispatches.""" client: FakeClient """The fake client for the actor.""" @@ -75,16 +83,16 @@ async def actor_env() -> AsyncIterator[ActorTestEnv]: ) actor.start() - - yield ActorTestEnv( - actor, - lifecycle_updates_dispatches.new_receiver(), - running_state_change_dispatches.new_receiver(), - client, - microgrid_id, - ) - - await actor.stop() + try: + yield ActorTestEnv( + actor=actor, + updated_dispatches=lifecycle_updates_dispatches.new_receiver(), + running_state_change=running_state_change_dispatches.new_receiver(), + client=client, + microgrid_id=microgrid_id, + ) + finally: + await actor.stop() @fixture @@ -124,7 +132,7 @@ def update_dispatch(sample: BaseDispatch, dispatch: BaseDispatch) -> BaseDispatc async def _test_new_dispatch_created( actor_env: ActorTestEnv, sample: BaseDispatch, -) -> BaseDispatch: +) -> Dispatch: """Test that a new dispatch is created. Args: @@ -142,10 +150,12 @@ async def _test_new_dispatch_created( case Deleted(dispatch) | Updated(dispatch): assert False, "Expected a created event" case Created(dispatch): - sample = update_dispatch(sample, dispatch) - assert dispatch == Dispatch(sample) + received = Dispatch(update_dispatch(sample, dispatch)) + received._set_running_status_notified() # pylint: disable=protected-access + dispatch._set_running_status_notified() # pylint: disable=protected-access + assert dispatch == received - return sample + return dispatch async def test_existing_dispatch_updated( @@ -166,7 +176,7 @@ async def test_existing_dispatch_updated( sample = await _test_new_dispatch_created(actor_env, sample) fake_time.shift(timedelta(seconds=1)) - await actor_env.client.update( + updated = await actor_env.client.update( microgrid_id=actor_env.microgrid_id, dispatch_id=sample.id, new_fields={ @@ -179,17 +189,10 @@ async def test_existing_dispatch_updated( dispatch_event = await actor_env.updated_dispatches.receive() match dispatch_event: case Created(dispatch) | Deleted(dispatch): - assert False, "Expected an updated event" + assert False, f"Expected an updated event, got {dispatch_event}" case Updated(dispatch): - sample = update_dispatch(sample, dispatch) - sample = replace( - sample, - active=True, - recurrence=replace(sample.recurrence, frequency=Frequency.UNSPECIFIED), - ) - assert dispatch == Dispatch( - sample, + updated, running_state_change_synced=dispatch.running_state_change_synced, ) @@ -200,9 +203,7 @@ async def test_existing_dispatch_deleted( fake_time: time_machine.Coordinates, ) -> None: """Test that an existing dispatch is deleted.""" - sample = generator.generate_dispatch() - - sample = await _test_new_dispatch_created(actor_env, sample) + sample = await _test_new_dispatch_created(actor_env, generator.generate_dispatch()) await actor_env.client.delete( microgrid_id=actor_env.microgrid_id, dispatch_id=sample.id @@ -210,14 +211,129 @@ async def test_existing_dispatch_deleted( fake_time.shift(timedelta(seconds=10)) await asyncio.sleep(10) - print("Awaiting deleted dispatch update") dispatch_event = await actor_env.updated_dispatches.receive() match dispatch_event: case Created(dispatch) | Updated(dispatch): assert False, "Expected a deleted event" case Deleted(dispatch): - sample = update_dispatch(sample, dispatch) - assert dispatch == Dispatch(sample, deleted=True) + sample._set_deleted() # pylint: disable=protected-access + dispatch._set_running_status_notified() # pylint: disable=protected-access + assert dispatch == sample + + +async def test_dispatch_inf_duration_deleted( + actor_env: ActorTestEnv, + generator: DispatchGenerator, + fake_time: time_machine.Coordinates, +) -> None: + """Test that a dispatch with infinite duration can be deleted while running.""" + # Generate a dispatch with infinite duration (duration=None) + sample = generator.generate_dispatch() + sample = replace( + sample, active=True, duration=None, start_time=_now() + timedelta(seconds=5) + ) + # Create the dispatch + sample = await _test_new_dispatch_created(actor_env, sample) + # Advance time to when the dispatch should start + fake_time.shift(timedelta(seconds=40)) + await asyncio.sleep(40) + # Expect notification of the dispatch being ready to run + ready_dispatch = await actor_env.running_state_change.receive() + assert ready_dispatch.running(sample.type) == RunningState.RUNNING + + # Now delete the dispatch + await actor_env.client.delete( + microgrid_id=actor_env.microgrid_id, dispatch_id=sample.id + ) + fake_time.shift(timedelta(seconds=10)) + await asyncio.sleep(1) + # Expect notification to stop the dispatch + done_dispatch = await actor_env.running_state_change.receive() + assert done_dispatch.running(sample.type) == RunningState.STOPPED + + +async def test_dispatch_inf_duration_updated_stopped_started( + actor_env: ActorTestEnv, + generator: DispatchGenerator, + fake_time: time_machine.Coordinates, +) -> None: + """Test that a dispatch with infinite duration can be stopped and started by updating it.""" + # Generate a dispatch with infinite duration (duration=None) + sample = generator.generate_dispatch() + sample = replace( + sample, active=True, duration=None, start_time=_now() + timedelta(seconds=5) + ) + # Create the dispatch + sample = await _test_new_dispatch_created(actor_env, sample) + # Advance time to when the dispatch should start + fake_time.shift(timedelta(seconds=40)) + await asyncio.sleep(40) + # Expect notification of the dispatch being ready to run + ready_dispatch = await actor_env.running_state_change.receive() + assert ready_dispatch.running(sample.type) == RunningState.RUNNING + + # Now update the dispatch to set active=False (stop it) + await actor_env.client.update( + microgrid_id=actor_env.microgrid_id, + dispatch_id=sample.id, + new_fields={"active": False}, + ) + fake_time.shift(timedelta(seconds=10)) + await asyncio.sleep(1) + # Expect notification to stop the dispatch + stopped_dispatch = await actor_env.running_state_change.receive() + assert stopped_dispatch.running(sample.type) == RunningState.STOPPED + + # Now update the dispatch to set active=True (start it again) + await actor_env.client.update( + microgrid_id=actor_env.microgrid_id, + dispatch_id=sample.id, + new_fields={"active": True}, + ) + fake_time.shift(timedelta(seconds=10)) + await asyncio.sleep(1) + # Expect notification of the dispatch being ready to run again + started_dispatch = await actor_env.running_state_change.receive() + assert started_dispatch.running(sample.type) == RunningState.RUNNING + + +async def test_dispatch_inf_duration_updated_to_finite_and_stops( + actor_env: ActorTestEnv, + generator: DispatchGenerator, + fake_time: time_machine.Coordinates, +) -> None: + """Test updating an inf. duration changing to finite. + + Test that updating an infinite duration dispatch to a finite duration causes + it to stop if the duration has passed. + """ + # Generate a dispatch with infinite duration (duration=None) + sample = generator.generate_dispatch() + sample = replace( + sample, active=True, duration=None, start_time=_now() + timedelta(seconds=5) + ) + # Create the dispatch + sample = await _test_new_dispatch_created(actor_env, sample) + # Advance time to when the dispatch should start + fake_time.shift(timedelta(seconds=10)) + await asyncio.sleep(1) + # Expect notification of the dispatch being ready to run + ready_dispatch = await actor_env.running_state_change.receive() + assert ready_dispatch.running(sample.type) == RunningState.RUNNING + + # Update the dispatch to set duration to a finite duration that has already passed + # The dispatch has been running for 5 seconds; set duration to 5 seconds + await actor_env.client.update( + microgrid_id=actor_env.microgrid_id, + dispatch_id=sample.id, + new_fields={"duration": timedelta(seconds=5)}, + ) + # Advance time to allow the update to be processed + fake_time.shift(timedelta(seconds=1)) + await asyncio.sleep(1) + # Expect notification to stop the dispatch because the duration has passed + stopped_dispatch = await actor_env.running_state_change.receive() + assert stopped_dispatch.running(sample.type) == RunningState.STOPPED async def test_dispatch_schedule( @@ -226,7 +342,9 @@ async def test_dispatch_schedule( fake_time: time_machine.Coordinates, ) -> None: """Test that a random dispatch is scheduled correctly.""" - sample = generator.generate_dispatch() + sample = replace( + generator.generate_dispatch(), active=True, duration=timedelta(seconds=10) + ) await actor_env.client.create(**to_create_params(actor_env.microgrid_id, sample)) dispatch = Dispatch(actor_env.client.dispatches(actor_env.microgrid_id)[0]) @@ -237,14 +355,144 @@ async def test_dispatch_schedule( await asyncio.sleep(1) # Expect notification of the dispatch being ready to run - ready_dispatch = await actor_env.ready_dispatches.receive() + ready_dispatch = await actor_env.running_state_change.receive() + + # Set flag we expect to be different to compare the dispatch with the one received + dispatch._set_running_status_notified() # pylint: disable=protected-access assert ready_dispatch == dispatch + assert dispatch.duration is not None # Shift time to the end of the dispatch fake_time.shift(dispatch.duration + timedelta(seconds=1)) await asyncio.sleep(1) # Expect notification to stop the dispatch - done_dispatch = await actor_env.ready_dispatches.receive() + done_dispatch = await actor_env.running_state_change.receive() assert done_dispatch == dispatch + + +async def test_dispatch_inf_duration_updated_to_finite_and_continues( + actor_env: ActorTestEnv, + generator: DispatchGenerator, + fake_time: time_machine.Coordinates, +) -> None: + """Test that updating an infinite duration dispatch to a finite duration. + + Test that updating an infinite duration dispatch to a finite + allows it to continue running if the duration hasn't passed. + """ + # Generate a dispatch with infinite duration (duration=None) + sample = generator.generate_dispatch() + sample = replace( + sample, active=True, duration=None, start_time=_now() + timedelta(seconds=5) + ) + # Create the dispatch + sample = await _test_new_dispatch_created(actor_env, sample) + # Advance time to when the dispatch should start + fake_time.shift(timedelta(seconds=10)) + await asyncio.sleep(1) + # Expect notification of the dispatch being ready to run + ready_dispatch = await actor_env.running_state_change.receive() + assert ready_dispatch.running(sample.type) == RunningState.RUNNING + + # Update the dispatch to set duration to a finite duration that hasn't passed yet + # The dispatch has been running for 5 seconds; set duration to 100 seconds + await actor_env.client.update( + microgrid_id=actor_env.microgrid_id, + dispatch_id=sample.id, + new_fields={"duration": timedelta(seconds=100)}, + ) + # Advance time slightly to process the update + fake_time.shift(timedelta(seconds=1)) + await asyncio.sleep(1) + # The dispatch should continue running + # Advance time until the total running time reaches 100 seconds + fake_time.shift(timedelta(seconds=94)) + await asyncio.sleep(1) + # Expect notification to stop the dispatch because the duration has now passed + stopped_dispatch = await actor_env.running_state_change.receive() + assert stopped_dispatch.running(sample.type) == RunningState.STOPPED + + +async def test_dispatch_new_but_finished( + actor_env: ActorTestEnv, + generator: DispatchGenerator, + fake_time: time_machine.Coordinates, +) -> None: + """Test that a dispatch that is already finished is not started.""" + # Generate a dispatch that is already finished + finished_dispatch = generator.generate_dispatch() + finished_dispatch = replace( + finished_dispatch, + active=True, + duration=timedelta(seconds=5), + start_time=_now() - timedelta(seconds=50), + recurrence=RecurrenceRule(), + type="I_SHOULD_NEVER_RUN", + ) + # Create an old dispatch + actor_env.client.set_dispatches(actor_env.microgrid_id, [finished_dispatch]) + await actor_env.actor.stop() + actor_env.actor.start() + + # Create another dispatch the normal way + new_dispatch = generator.generate_dispatch() + new_dispatch = replace( + new_dispatch, + active=True, + duration=timedelta(seconds=10), + start_time=_now() + timedelta(seconds=5), + recurrence=RecurrenceRule(), + type="NEW_BETTER_DISPATCH", + ) + # Consume one lifecycle_updates event + await actor_env.updated_dispatches.receive() + new_dispatch = await _test_new_dispatch_created(actor_env, new_dispatch) + + # Advance time to when the new dispatch should still not start + fake_time.shift(timedelta(seconds=100)) + + assert await actor_env.running_state_change.receive() == new_dispatch + + +async def test_notification_on_actor_start( + actor_env: ActorTestEnv, + generator: DispatchGenerator, + fake_time: time_machine.Coordinates, +) -> None: + """Test that the actor sends notifications for all running dispatches on start.""" + # Generate a dispatch that is already running + running_dispatch = generator.generate_dispatch() + running_dispatch = replace( + running_dispatch, + active=True, + duration=timedelta(seconds=10), + start_time=_now() - timedelta(seconds=5), + recurrence=RecurrenceRule(), + type="I_SHOULD_RUN", + ) + # Generate a dispatch that is not running + stopped_dispatch = generator.generate_dispatch() + stopped_dispatch = replace( + stopped_dispatch, + active=False, + duration=timedelta(seconds=5), + start_time=_now() - timedelta(seconds=5), + recurrence=RecurrenceRule(), + type="I_SHOULD_NOT_RUN", + ) + await actor_env.actor.stop() + + # Create the dispatches + actor_env.client.set_dispatches( + actor_env.microgrid_id, [running_dispatch, stopped_dispatch] + ) + actor_env.actor.start() + + fake_time.shift(timedelta(seconds=1)) + await asyncio.sleep(1) + + # Expect notification of the running dispatch being ready to run + ready_dispatch = await actor_env.running_state_change.receive() + assert ready_dispatch.running(running_dispatch.type) == RunningState.RUNNING diff --git a/tests/test_mananging_actor.py b/tests/test_mananging_actor.py new file mode 100644 index 0000000..100e2a7 --- /dev/null +++ b/tests/test_mananging_actor.py @@ -0,0 +1,164 @@ +# LICENSE: ALL RIGHTS RESERVED +# Copyright © 2024 Frequenz Energy-as-a-Service GmbH + +"""Test the dispatch runner.""" + +import asyncio +from dataclasses import dataclass, replace +from datetime import datetime, timedelta, timezone +from typing import AsyncIterator, Iterator + +import async_solipsism +import time_machine +from frequenz.channels import Broadcast, Receiver, Sender +from frequenz.client.dispatch.test.generator import DispatchGenerator +from frequenz.client.dispatch.types import Frequency +from frequenz.sdk.actor import Actor +from pytest import fixture + +from frequenz.dispatch import Dispatch, DispatchManagingActor, DispatchUpdate + + +@fixture +def event_loop_policy() -> async_solipsism.EventLoopPolicy: + """Set the event loop policy to use async_solipsism.""" + policy = async_solipsism.EventLoopPolicy() + asyncio.set_event_loop_policy(policy) + return policy + + +@fixture +def fake_time() -> Iterator[time_machine.Coordinates]: + """Replace real time with a time machine that doesn't automatically tick.""" + # destination can be a datetime or a timestamp (int), so are moving to the + # epoch (in UTC!) + with time_machine.travel(destination=0, tick=False) as traveller: + yield traveller + + +def _now() -> datetime: + """Return the current time in UTC.""" + return datetime.now(tz=timezone.utc) + + +class MockActor(Actor): + """Mock actor for testing.""" + + async def _run(self) -> None: + while True: + await asyncio.sleep(1) + + +@dataclass +class TestEnv: + """Test environment.""" + + actor: Actor + runner_actor: DispatchManagingActor + running_status_sender: Sender[Dispatch] + updates_receiver: Receiver[DispatchUpdate] + generator: DispatchGenerator = DispatchGenerator() + + +@fixture +async def test_env() -> AsyncIterator[TestEnv]: + """Create a test environment.""" + channel = Broadcast[Dispatch](name="dispatch ready test channel") + updates_channel = Broadcast[DispatchUpdate](name="dispatch update test channel") + + actor = MockActor() + + runner_actor = DispatchManagingActor( + actor=actor, + dispatch_type="UNIT_TEST", + running_status_receiver=channel.new_receiver(), + updates_sender=updates_channel.new_sender(), + ) + + runner_actor.start() + + yield TestEnv( + actor=actor, + runner_actor=runner_actor, + running_status_sender=channel.new_sender(), + updates_receiver=updates_channel.new_receiver(), + ) + + await runner_actor.stop() + + +async def test_simple_start_stop( + test_env: TestEnv, + fake_time: time_machine.Coordinates, +) -> None: + """Test behavior when receiving start/stop messages.""" + now = _now() + duration = timedelta(minutes=10) + dispatch = test_env.generator.generate_dispatch() + dispatch = replace( + dispatch, + active=True, + dry_run=False, + duration=duration, + start_time=now, + payload={"test": True}, + type="UNIT_TEST", + recurrence=replace( + dispatch.recurrence, + frequency=Frequency.UNSPECIFIED, + ), + ) + + await test_env.running_status_sender.send(Dispatch(dispatch)) + fake_time.shift(timedelta(seconds=1)) + + event = await test_env.updates_receiver.receive() + assert event.options == {"test": True} + assert event.components == dispatch.selector + assert event.dry_run is False + + assert test_env.actor.is_running is True + + fake_time.shift(duration) + await test_env.running_status_sender.send(Dispatch(dispatch)) + + # Give await actor.stop a chance to run in DispatchManagingActor + await asyncio.sleep(0.1) + + assert test_env.actor.is_running is False + + +async def test_dry_run(test_env: TestEnv, fake_time: time_machine.Coordinates) -> None: + """Test the dry run mode.""" + dispatch = test_env.generator.generate_dispatch() + dispatch = replace( + dispatch, + dry_run=True, + active=True, + start_time=_now(), + duration=timedelta(minutes=10), + type="UNIT_TEST", + recurrence=replace( + dispatch.recurrence, + frequency=Frequency.UNSPECIFIED, + ), + ) + + await test_env.running_status_sender.send(Dispatch(dispatch)) + fake_time.shift(timedelta(seconds=1)) + + event = await test_env.updates_receiver.receive() + + assert event.dry_run is dispatch.dry_run + assert event.components == dispatch.selector + assert event.options == dispatch.payload + assert test_env.actor.is_running is True + + assert dispatch.duration is not None + fake_time.shift(dispatch.duration) + await test_env.running_status_sender.send(Dispatch(dispatch)) + + # Give await actor.stop a chance to run in DispatchManagingActor + await asyncio.sleep(0.1) + + assert test_env.actor.is_running is False