diff --git a/conftest.py b/conftest.py index fcdcaf2c..39cfdf91 100644 --- a/conftest.py +++ b/conftest.py @@ -1,4 +1,6 @@ +import io import sys +from unittest import mock import pytest @@ -31,3 +33,25 @@ def pytest_ignore_collect(collection_path, path, config): if "django_project" in str(path): return True + + +@pytest.fixture(autouse=True, scope="module") +def mock_dot_write(request): + def open_effect( + filename, + mode="r", + *args, + **kwargs, + ): + if mode in ("r", "rt", "rb"): + return open(filename, mode, *args, **kwargs) + elif filename.startswith("/tmp/"): + return open(filename, mode, *args, **kwargs) + elif "b" in mode: + return io.BytesIO() + else: + return io.StringIO() + + with mock.patch("pydot.core.io.open", spec=True) as m: + m.side_effect = open_effect + yield m diff --git a/docs/_static/custom_machine.css b/docs/_static/custom_machine.css index 76c99cc6..7eebd625 100644 --- a/docs/_static/custom_machine.css +++ b/docs/_static/custom_machine.css @@ -28,6 +28,7 @@ } + /* Gallery Donwload buttons */ div.sphx-glr-download a { color: #404040 !important; diff --git a/docs/diagram.md b/docs/diagram.md index 2238bf46..91b72258 100644 --- a/docs/diagram.md +++ b/docs/diagram.md @@ -42,7 +42,7 @@ Graphviz. For example, on Debian-based systems (such as Ubuntu), you can use the >>> dot = graph() >>> dot.to_string() # doctest: +ELLIPSIS -'digraph list {... +'digraph OrderControl {... ``` diff --git a/docs/images/order_control_machine_initial.png b/docs/images/order_control_machine_initial.png index bd5cf06d..fed73a4e 100644 Binary files a/docs/images/order_control_machine_initial.png and b/docs/images/order_control_machine_initial.png differ diff --git a/docs/images/order_control_machine_processing.png b/docs/images/order_control_machine_processing.png index 5355f078..f4ddae7a 100644 Binary files a/docs/images/order_control_machine_processing.png and b/docs/images/order_control_machine_processing.png differ diff --git a/docs/images/readme_trafficlightmachine.png b/docs/images/readme_trafficlightmachine.png index f52ea2cc..2735e5f1 100644 Binary files a/docs/images/readme_trafficlightmachine.png and b/docs/images/readme_trafficlightmachine.png differ diff --git a/docs/images/test_state_machine_internal.png b/docs/images/test_state_machine_internal.png index 77806cdb..37404ce3 100644 Binary files a/docs/images/test_state_machine_internal.png and b/docs/images/test_state_machine_internal.png differ diff --git a/docs/states.md b/docs/states.md index 6b1d43e2..b127be35 100644 --- a/docs/states.md +++ b/docs/states.md @@ -17,6 +17,7 @@ How to define and attach [](actions.md) to {ref}`States`. A {ref}`StateMachine` should have one and only one `initial` {ref}`state`. +If not specified, the default initial state is the first child state in document order. The initial {ref}`state` is entered when the machine starts and the corresponding entering state {ref}`actions` are called if defined. diff --git a/statemachine/contrib/diagram.py b/statemachine/contrib/diagram.py index 694cae7d..9e434334 100644 --- a/statemachine/contrib/diagram.py +++ b/statemachine/contrib/diagram.py @@ -32,21 +32,34 @@ class DotGraphMachine: def __init__(self, machine): self.machine = machine - def _get_graph(self): - machine = self.machine + def _get_graph(self, machine): return pydot.Dot( - "list", + machine.name, graph_type="digraph", label=machine.name, fontname=self.font_name, fontsize=self.state_font_size, rankdir=self.graph_rankdir, + compound="true", ) - def _initial_node(self): + def _get_subgraph(self, state): + style = ", solid" + if state.parent and state.parent.parallel: + style = ", dashed" + subgraph = pydot.Subgraph( + label=f"{state.name}", + graph_name=f"cluster_{state.id}", + style=f"rounded{style}", + cluster="true", + ) + return subgraph + + def _initial_node(self, state): node = pydot.Node( - "i", - shape="circle", + self._state_id(state), + label="", + shape="point", style="filled", fontsize="1pt", fixedsize="true", @@ -56,14 +69,18 @@ def _initial_node(self): node.set_fillcolor("black") return node - def _initial_edge(self): + def _initial_edge(self, initial_node, state): + extra_params = {} + if state.states: + extra_params["lhead"] = f"cluster_{state.id}" return pydot.Edge( - "i", - self.machine.initial_state.id, + initial_node.get_name(), + self._state_id(state), label="", color="blue", fontname=self.font_name, fontsize=self.transition_font_size, + **extra_params, ) def _actions_getter(self): @@ -104,11 +121,18 @@ def _state_actions(self, state): return actions + @staticmethod + def _state_id(state): + if state.states: + return f"{state.id}_anchor" + else: + return state.id + def _state_as_node(self, state): actions = self._state_actions(state) node = pydot.Node( - state.id, + self._state_id(state), label=f"{state.name}{actions}", shape="rectangle", style="rounded, filled", @@ -127,29 +151,64 @@ def _transition_as_edge(self, transition): cond = ", ".join([str(cond) for cond in transition.cond]) if cond: cond = f"\n[{cond}]" + + extra_params = {} + has_substates = transition.source.states or transition.target.states + if transition.source.states: + extra_params["ltail"] = f"cluster_{transition.source.id}" + if transition.target.states: + extra_params["lhead"] = f"cluster_{transition.target.id}" + return pydot.Edge( - transition.source.id, - transition.target.id, + self._state_id(transition.source), + self._state_id(transition.target), label=f"{transition.event}{cond}", color="blue", fontname=self.font_name, fontsize=self.transition_font_size, + minlen=2 if has_substates else 1, + **extra_params, ) def get_graph(self): - graph = self._get_graph() - graph.add_node(self._initial_node()) - graph.add_edge(self._initial_edge()) + graph = self._get_graph(self.machine) + self._graph_states(self.machine, graph) + return graph - for state in self.machine.states: - graph.add_node(self._state_as_node(state)) - for transition in state.transitions: + def _graph_states(self, state, graph): + initial_node = self._initial_node(state) + initial_subgraph = pydot.Subgraph( + graph_name=f"{initial_node.get_name()}_initial", + label="", + peripheries=0, + margin=0, + ) + atomic_states_subgraph = pydot.Subgraph( + graph_name=f"cluster_{initial_node.get_name()}_atomic", + label="", + peripheries=0, + cluster="true", + ) + initial_subgraph.add_node(initial_node) + graph.add_subgraph(initial_subgraph) + graph.add_subgraph(atomic_states_subgraph) + + initial = next(s for s in state.states if s.initial) + graph.add_edge(self._initial_edge(initial_node, initial)) + + for substate in state.states: + if substate.states: + subgraph = self._get_subgraph(substate) + self._graph_states(substate, subgraph) + graph.add_subgraph(subgraph) + else: + atomic_states_subgraph.add_node(self._state_as_node(substate)) + + for transition in substate.transitions: if transition.internal: continue graph.add_edge(self._transition_as_edge(transition)) - return graph - def __call__(self): return self.get_graph() @@ -165,7 +224,8 @@ def quickchart_write_svg(sm: StateMachine, path: str): >>> from tests.examples.order_control_machine import OrderControl >>> sm = OrderControl() >>> print(sm._graph().to_string()) - digraph list { + digraph OrderControl { + compound=true; fontname=Arial; fontsize="10pt"; label=OrderControl; diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 84310836..bb363ed9 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -1,25 +1,17 @@ -from threading import Lock -from typing import TYPE_CHECKING -from weakref import proxy - from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import InvalidDefinition from ..exceptions import TransitionNotAllowed from ..i18n import _ from ..transition import Transition - -if TYPE_CHECKING: - from ..statemachine import StateMachine +from .base import BaseEngine -class AsyncEngine: - def __init__(self, sm: "StateMachine", rtc: bool = True): - self.sm = proxy(sm) - self._sentinel = object() +class AsyncEngine(BaseEngine): + def __init__(self, sm, rtc: bool = True): if not rtc: raise InvalidDefinition(_("Only RTC is supported on async engine")) - self._processing = Lock() + super().__init__(sm, rtc) async def activate_initial_state(self): """ @@ -63,8 +55,8 @@ async def processing_loop(self): first_result = self._sentinel try: # Execute the triggers in the queue in FIFO order until the queue is empty - while self.sm._external_queue: - trigger_data = self.sm._external_queue.popleft() + while self._external_queue: + trigger_data = self._external_queue.popleft() try: result = await self._trigger(trigger_data) if first_result is self._sentinel: @@ -72,7 +64,7 @@ async def processing_loop(self): except Exception: # Whe clear the queue as we don't have an expected behavior # and cannot keep processing - self.sm._external_queue.clear() + self._external_queue.clear() raise finally: self._processing.release() diff --git a/statemachine/engines/base.py b/statemachine/engines/base.py new file mode 100644 index 00000000..0d3e8c67 --- /dev/null +++ b/statemachine/engines/base.py @@ -0,0 +1,34 @@ +from collections import deque +from threading import Lock +from typing import TYPE_CHECKING +from weakref import proxy + +from ..event_data import TriggerData + +if TYPE_CHECKING: + from ..statemachine import StateMachine + + +class BaseEngine: + def __init__(self, sm: "StateMachine", rtc: bool = True) -> None: + self.sm = proxy(sm) + self._external_queue: deque = deque() + self._sentinel = object() + self._rtc = rtc + self._processing = Lock() + self._put_initial_activation_trigger_on_queue() + + def _put_nonblocking(self, trigger_data: TriggerData): + """Put the trigger on the queue without blocking the caller.""" + self._external_queue.append(trigger_data) + + def _put_initial_activation_trigger_on_queue(self): + # Activate the initial state, this only works if the outer scope is sync code. + # for async code, the user should manually call `await sm.activate_initial_state()` + # after state machine creation. + if self.sm.current_state_value is None: + trigger_data = TriggerData( + machine=self.sm, + event="__initial__", + ) + self._put_nonblocking(trigger_data) diff --git a/statemachine/engines/statechart.py b/statemachine/engines/statechart.py new file mode 100644 index 00000000..f7be7a09 --- /dev/null +++ b/statemachine/engines/statechart.py @@ -0,0 +1,130 @@ +from ..event_data import EventData +from ..event_data import TriggerData +from ..exceptions import TransitionNotAllowed +from ..transition import Transition +from .base import BaseEngine + + +class StateChartEngine(BaseEngine): + def __init__(self, sm, rtc: bool = True): + super().__init__(sm, rtc) + self.activate_initial_state() + + def activate_initial_state(self): + """ + Activate the initial state. + + Called automatically on state machine creation from sync code, but in + async code, the user must call this method explicitly. + + Given how async works on python, there's no built-in way to activate the initial state that + may depend on async code from the StateMachine.__init__ method. + """ + return self.processing_loop() + + def processing_loop(self): + """Process event triggers. + + The simplest implementation is the non-RTC (synchronous), + where the trigger will be run immediately and the result collected as the return. + + .. note:: + + While processing the trigger, if others events are generated, they + will also be processed immediately, so a "nested" behavior happens. + + If the machine is on ``rtc`` model (queued), the event is put on a queue, and only the + first event will have the result collected. + + .. note:: + While processing the queue items, if others events are generated, they + will be processed sequentially (and not nested). + + """ + if not self._rtc: + # The machine is in "synchronous" mode + trigger_data = self._external_queue.popleft() + return self._trigger(trigger_data) + + # We make sure that only the first event enters the processing critical section, + # next events will only be put on the queue and processed by the same loop. + if not self._processing.acquire(blocking=False): + return None + + # We will collect the first result as the processing result to keep backwards compatibility + # so we need to use a sentinel object instead of `None` because the first result may + # be also `None`, and on this case the `first_result` may be overridden by another result. + first_result = self._sentinel + try: + # Execute the triggers in the queue in FIFO order until the queue is empty + while self._external_queue: + trigger_data = self._external_queue.popleft() + try: + result = self._trigger(trigger_data) + if first_result is self._sentinel: + first_result = result + except Exception: + # Whe clear the queue as we don't have an expected behavior + # and cannot keep processing + self._external_queue.clear() + raise + finally: + self._processing.release() + return first_result if first_result is not self._sentinel else None + + def _trigger(self, trigger_data: TriggerData): + event_data = None + if trigger_data.event == "__initial__": + transition = Transition(None, self.sm._get_initial_state(), event="__initial__") + transition._specs.clear() + event_data = EventData(trigger_data=trigger_data, transition=transition) + self._activate(event_data) + return self._sentinel + + state = self.sm.current_state + for transition in state.transitions: + if not transition.match(trigger_data.event): + continue + + event_data = EventData(trigger_data=trigger_data, transition=transition) + args, kwargs = event_data.args, event_data.extended_kwargs + self.sm._get_callbacks(transition.validators.key).call(*args, **kwargs) + if not self.sm._get_callbacks(transition.cond.key).all(*args, **kwargs): + continue + + result = self._activate(event_data) + event_data.result = result + event_data.executed = True + break + else: + if not self.sm.allow_event_without_transition: + raise TransitionNotAllowed(trigger_data.event, state) + + return event_data.result if event_data else None + + def _activate(self, event_data: EventData): + args, kwargs = event_data.args, event_data.extended_kwargs + transition = event_data.transition + source = event_data.state + target = transition.target + + result = self.sm._get_callbacks(transition.before.key).call(*args, **kwargs) + if source is not None and not transition.internal: + self.sm._get_callbacks(source.exit.key).call(*args, **kwargs) + + result += self.sm._get_callbacks(transition.on.key).call(*args, **kwargs) + + self.sm.current_state = target + event_data.state = target + kwargs["state"] = target + + if not transition.internal: + self.sm._get_callbacks(target.enter.key).call(*args, **kwargs) + self.sm._get_callbacks(transition.after.key).call(*args, **kwargs) + + if len(result) == 0: + result = None + elif len(result) == 1: + result = result[0] + + return result diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index c89c7570..74697b00 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -1,22 +1,13 @@ -from threading import Lock -from typing import TYPE_CHECKING -from weakref import proxy - from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import TransitionNotAllowed from ..transition import Transition - -if TYPE_CHECKING: - from ..statemachine import StateMachine +from .base import BaseEngine -class SyncEngine: - def __init__(self, sm: "StateMachine", rtc: bool = True): - self.sm = proxy(sm) - self._sentinel = object() - self._rtc = rtc - self._processing = Lock() +class SyncEngine(BaseEngine): + def __init__(self, sm, rtc: bool = True): + super().__init__(sm, rtc) self.activate_initial_state() def activate_initial_state(self): @@ -52,7 +43,7 @@ def processing_loop(self): """ if not self._rtc: # The machine is in "synchronous" mode - trigger_data = self.sm._external_queue.popleft() + trigger_data = self._external_queue.popleft() return self._trigger(trigger_data) # We make sure that only the first event enters the processing critical section, @@ -66,8 +57,8 @@ def processing_loop(self): first_result = self._sentinel try: # Execute the triggers in the queue in FIFO order until the queue is empty - while self.sm._external_queue: - trigger_data = self.sm._external_queue.popleft() + while self._external_queue: + trigger_data = self._external_queue.popleft() try: result = self._trigger(trigger_data) if first_result is self._sentinel: @@ -75,7 +66,7 @@ def processing_loop(self): except Exception: # Whe clear the queue as we don't have an expected behavior # and cannot keep processing - self.sm._external_queue.clear() + self._external_queue.clear() raise finally: self._processing.release() diff --git a/statemachine/factory.py b/statemachine/factory.py index 40e3db15..b93b9db6 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -10,6 +10,7 @@ from .event import Event from .event import trigger_event_factory from .exceptions import InvalidDefinition +from .graph import iterate_states from .graph import iterate_states_and_transitions from .graph import visit_connected_states from .i18n import _ @@ -32,6 +33,7 @@ def __init__( super().__init__(name, bases, attrs) registry.register(cls) cls.name = cls.__name__ + cls.id = cls.name.lower() cls.states: States = States() cls.states_map: Dict[Any, State] = {} """Map of ``state.value`` to the corresponding :ref:`state`.""" @@ -43,9 +45,27 @@ def __init__( cls.add_inherited(bases) cls.add_from_attributes(attrs) + cls._unpack_builders_callbacks() + + if not cls.states: + return + + cls._initials_by_document_order(cls.states) + + initials = [s for s in cls.states if s.initial] + parallels = [s.id for s in cls.states if s.parallel] + root_only_has_parallels = len(cls.states) == len(parallels) + + if len(initials) != 1 and not root_only_has_parallels: + raise InvalidDefinition( + _( + "There should be one and only one initial state. " + "Your currently have these: {0}" + ).format(", ".join(s.id for s in initials)) + ) try: - cls.initial_state: State = next(s for s in cls.states if s.initial) + cls.initial_state: State = next(s for s in initials if s.initial) except StopIteration: cls.initial_state = None # Abstract SM still don't have states @@ -59,6 +79,16 @@ def __init__( def __getattr__(self, attribute: str) -> Any: ... + def _initials_by_document_order(cls, states): + has_initial = False + for s in states: + cls._initials_by_document_order(s.states) + if s.initial: + has_initial = True + break + if not has_initial and states: + states[0]._initial = True + def _check(cls): has_states = bool(cls.states) has_events = bool(cls._events) @@ -168,6 +198,15 @@ def _setup(cls): "send", } | {s.id for s in cls.states} + def _unpack_builders_callbacks(cls): + callbacks = {} + for state in iterate_states(cls.states): + if state._callbacks: + callbacks.update(state._callbacks) + del state._callbacks + for key, value in callbacks.items(): + setattr(cls, key, value) + def add_inherited(cls, bases): for base in bases: for state in getattr(base, "states", []): @@ -205,15 +244,19 @@ def _add_unbounded_callback(cls, attr_name, func): def add_state(cls, id, state: State): state._set_id(id) - cls.states.append(state) - cls.states_map[state.value] = state - if not hasattr(cls, id): - setattr(cls, id, state) + if not state.parent: + cls.states.append(state) + cls.states_map[state.value] = state + if not hasattr(cls, id): + setattr(cls, id, state) # also register all events associated directly with transitions for event in state.transitions.unique_events: cls.add_event(event) + for substate in state.states: + cls.add_state(substate.id, substate) + def add_event(cls, event, transitions=None): if transitions is not None: transitions.add_event(event) diff --git a/statemachine/graph.py b/statemachine/graph.py index ef3c013a..3d6598c6 100644 --- a/statemachine/graph.py +++ b/statemachine/graph.py @@ -18,3 +18,12 @@ def iterate_states_and_transitions(states): for state in states: yield state yield from state.transitions + if state.states: + yield from iterate_states_and_transitions(state.states) + + +def iterate_states(states): + for state in states: + yield state + if state.states: + yield from iterate_states(state.states) diff --git a/statemachine/state.py b/statemachine/state.py index 4f67d5cd..2ee29e61 100644 --- a/statemachine/state.py +++ b/statemachine/state.py @@ -15,6 +15,27 @@ from .statemachine import StateMachine +class NestedStateFactory(type): + def __new__( # type: ignore [misc] + cls, classname, bases, attrs, name=None, **kwargs + ) -> "State": + if not bases: + return super().__new__(cls, classname, bases, attrs) # type: ignore [return-value] + + states = [] + callbacks = {} + for key, value in attrs.items(): + if isinstance(value, State): + value._set_id(key) + states.append(value) + elif isinstance(value, TransitionList): + value.add_event(key) + elif callable(value): + callbacks[key] = value + + return State(name=name, states=states, _callbacks=callbacks, **kwargs) + + class State: """ A State in a :ref:`StateMachine` describes a particular behavior of the machine. @@ -33,6 +54,7 @@ class State: value. initial: Set ``True`` if the ``State`` is the initial one. There must be one and only one initial state in a statemachine. Defaults to ``False``. + If not specified, the default initial state is the first child state in document order. final: Set ``True`` if represents a final state. A machine can have optionally many final states. Final states have no :ref:`transition` starting from It. Defaults to ``False``. @@ -94,20 +116,50 @@ class State: """ + class Builder(metaclass=NestedStateFactory): + # Mimic the :ref:`State` public API to help linters discover the result of the Builder + # class. + + @classmethod + def to(cls, *args: "State", **kwargs) -> "TransitionList": # pragma: no cover + """Create transitions to the given target states. + + .. note: This method is only a type hint for mypy. + The actual implementation belongs to the :ref:`State` class. + """ + return TransitionList() + + @classmethod + def from_(cls, *args: "State", **kwargs) -> "TransitionList": # pragma: no cover + """Create transitions from the given target states (reversed). + + .. note: This method is only a type hint for mypy. + The actual implementation belongs to the :ref:`State` class. + """ + return TransitionList() + def __init__( self, name: str = "", value: Any = None, initial: bool = False, final: bool = False, + parallel: bool = False, + states: Any = None, enter: Any = None, exit: Any = None, + _callbacks: Any = None, ): self.name = name self.value = value + self.parallel = parallel + self.states = states or [] + self.is_atomic = bool(not self.states) self._initial = initial self._final = final self._id: str = "" + self._callbacks = _callbacks + self.parent: "State" = None self.transitions = TransitionList() self._specs = CallbackSpecList() self.enter = self._specs.grouper(CallbackGroup.ENTER).add( @@ -116,6 +168,12 @@ def __init__( self.exit = self._specs.grouper(CallbackGroup.EXIT).add( exit, priority=CallbackPriority.INLINE ) + self._init_states() + + def _init_states(self): + for state in self.states: + state.parent = self + setattr(self, state.id, state) def __eq__(self, other): return isinstance(other, State) and self.name == other.name and self.id == other.id @@ -217,6 +275,7 @@ def __init__( ): self._state = ref(state) self._machine = ref(machine) + self._init_states() @property def name(self): @@ -262,3 +321,15 @@ def id(self) -> str: @property def is_active(self): return self._machine().current_state == self + + @property + def is_atomic(self): + return self._state().is_atomic + + @property + def parent(self): + return self._state().parent + + @property + def states(self): + return self._state().states diff --git a/statemachine/statemachine.py b/statemachine/statemachine.py index a2e750aa..7a5c376c 100644 --- a/statemachine/statemachine.py +++ b/statemachine/statemachine.py @@ -1,5 +1,4 @@ import warnings -from collections import deque from copy import deepcopy from functools import partial from inspect import isawaitable @@ -17,6 +16,7 @@ from .dispatcher import Listener from .dispatcher import Listeners from .engines.async_ import AsyncEngine +from .engines.statechart import StateChartEngine from .engines.sync import SyncEngine from .event import Event from .event_data import TriggerData @@ -83,7 +83,6 @@ def __init__( self.state_field = state_field self.start_value = start_value self.allow_event_without_transition = allow_event_without_transition - self._external_queue: deque = deque() self._callbacks_registry = CallbacksRegistry() self._states_for_instance: Dict[State, State] = {} @@ -94,22 +93,13 @@ def __init__( raise InvalidDefinition(_("There are no states or transitions.")) self._register_callbacks(listeners or []) - - # Activate the initial state, this only works if the outer scope is sync code. - # for async code, the user should manually call `await sm.activate_initial_state()` - # after state machine creation. - if self.current_state_value is None: - trigger_data = TriggerData( - machine=self, - event="__initial__", - ) - self._put_nonblocking(trigger_data) - self._engine = self._get_engine(rtc) def _get_engine(self, rtc: bool): if self._callbacks_registry.has_async_callbacks: return AsyncEngine(self, rtc=rtc) + elif any(bool(s.states) for s in self.states): + return StateChartEngine(self, rtc=rtc) else: return SyncEngine(self, rtc=rtc) @@ -305,7 +295,7 @@ def allowed_events(self): def _put_nonblocking(self, trigger_data: TriggerData): """Put the trigger on the queue without blocking the caller.""" - self._external_queue.append(trigger_data) + self._engine._put_nonblocking(trigger_data) def send(self, event: str, *args, **kwargs): """Send an :ref:`Event` to the state machine. diff --git a/statemachine/states.py b/statemachine/states.py index 1f5c2257..d91e702e 100644 --- a/statemachine/states.py +++ b/statemachine/states.py @@ -61,6 +61,9 @@ def __getattr__(self, name: str): def __len__(self): return len(self._states) + def __getitem__(self, index): + return list(self)[index] + def __iter__(self): return iter(self._states.values()) diff --git a/tests/examples/microwave_inheritance_machine.py b/tests/examples/microwave_inheritance_machine.py new file mode 100644 index 00000000..47aa5071 --- /dev/null +++ b/tests/examples/microwave_inheritance_machine.py @@ -0,0 +1,68 @@ +""" +Microwave machine +================= + +Example that exercises the Compound and Parallel states. + +Compound +-------- + +If there are more than one substates, one of them is usually designated as the initial state of +that compound state. + +When a compound state is active, its substates behave as though they were an active state machine: + Exactly one child state must also be active. This means that: + +When a compound state is entered, it must also enter exactly one of its substates, usually its +initial state. +When an event happens, the substates have priority when it comes to selecting which transition to +follow. If a substate happens to handles an event, the event is consumed, it isn’t passed to the +parent compound state. +When a substate transitions to another substate, both “inside” the compound state, the compound +state does not exit or enter; it remains active. +When a compound state exits, its substate is simultaneously exited too. (Technically, the substate +exits first, then its parent.) +Compound states may be nested, or include parallel states. + +The opposite of a compound state is an atomic state, which is a state with no substates. + +A compound state is allowed to define transitions to its child states. Normally, when a transition +leads from a state, it causes that state to be exited. For transitions from a compound state to +one of its descendants, it is possible to define a transition that avoids exiting and entering +the compound state itself, such transitions are called local transitions. + + +""" +from statemachine import State +from statemachine import StateMachine + + +class MicroWave(StateMachine): + class oven(State.Builder, name="Microwave oven", parallel=True): + class engine(State.Builder): + off = State("Off", initial=True) + + class on(State.Builder): + idle = State("Idle", initial=True) + cooking = State("Cooking") + + idle.to(cooking, cond="closed.is_active") + cooking.to(idle, cond="open.is_active") + cooking.to.itself(internal=True, on="increment_timer") + + turn_off = on.to(off) + turn_on = off.to(on) + on.to(off, cond="cook_time_is_over") # eventless transition + + class door(State.Builder): + closed = State(initial=True) + open = State() + + door_open = closed.to(open) + door_close = open.to(closed) + + def __init__(self): + self.cook_time = 5 + self.door_closed = True + self.timer = 0 + super().__init__() diff --git a/tests/examples/traffic_light_nested_machine.py b/tests/examples/traffic_light_nested_machine.py new file mode 100644 index 00000000..33220680 --- /dev/null +++ b/tests/examples/traffic_light_nested_machine.py @@ -0,0 +1,64 @@ +""" +Nested Traffic light machine +---------------------------- + +Demonstrates the concept of nested compound states. + +From this example on XState: https://xstate.js.org/docs/guides/hierarchical.html#api + +""" +import time + +from statemachine import State +from statemachine import StateMachine + + +class NestedTrafficLightMachine(StateMachine): + "A traffic light machine" + green = State(initial=True, enter="reset_elapsed") + yellow = State(enter="reset_elapsed") + + class red(State.Builder, enter="reset_elapsed"): + "Pedestrian states" + walk = State(initial=True) + wait = State() + stop = State() + blinking = State() + + ped_countdown = walk.to(wait) | wait.to(stop) + + timer = green.to(yellow) | yellow.to(red) | red.to(green) + power_outage = red.blinking.from_() + power_restored = red.from_() + + def __init__(self, seconds_to_turn_state=5, seconds_running=20): + self.seconds_to_turn_state = seconds_to_turn_state + self.seconds_running = seconds_running + super().__init__(allow_event_without_transition=True) + + def on_timer(self, event: str, source: State, target: State): + print(f".. Running {event} from {source.id} to {target.id}") + + def reset_elapsed(self, event: str, time: int = 0): + print(f"entering reset_elapsed from {event} with {time}") + self.last_turn = time + + @timer.cond + def time_is_over(self, time): + return time - self.last_turn > self.seconds_to_turn_state + + def run_forever(self): + self.running = True + start_time = time.time() + while self.running: + print("tick!") + time.sleep(1) + curr_time = time.time() + self.send("timer", time=curr_time) + + if curr_time - start_time > self.seconds_running: + self.running = False + + +sm = NestedTrafficLightMachine() +sm.send("anything") diff --git a/tests/test_compound.py b/tests/test_compound.py new file mode 100644 index 00000000..f26044a7 --- /dev/null +++ b/tests/test_compound.py @@ -0,0 +1,108 @@ +import pytest + +from statemachine import State +from statemachine.statemachine import StateMachine + + +@pytest.fixture() +def microwave_cls(): + from tests.examples.microwave_inheritance_machine import MicroWave + + return MicroWave + + +def assert_state(s, name, initial=False, final=False, parallel=False, substates=None): + if substates is None: + substates = [] + + assert isinstance(s, State) + assert s.name == name + assert s.initial is initial + assert s.final is final + assert s.parallel is parallel + assert isinstance(s, State) + assert set(s.states) == set(substates) + + +class TestNestedSyntax: + def test_capture_constructor_arguments(self, microwave_cls): + sm = microwave_cls() + + assert_state( + sm.oven, + "Microwave oven", + parallel=True, + substates=[sm.oven.engine, sm.oven.door], + ) + assert_state( + sm.oven.engine, + "Engine", + initial=False, + substates=[sm.oven.engine.on, sm.oven.engine.off], + ) + assert_state(sm.oven.engine.off, "Off", initial=True) + assert_state( + sm.oven.engine.on, + "On", + substates=[sm.oven.engine.on.idle, sm.oven.engine.on.cooking], + ) + assert_state( + sm.oven.door, + "Door", + initial=False, + substates=[sm.oven.door.closed, sm.oven.door.open], + ) + assert_state(sm.oven.door.closed, "Closed", initial=True) + assert_state(sm.oven.door.open, "Open") + + def test_list_children_states(self, microwave_cls): + sm = microwave_cls() + assert [s.id for s in sm.oven.engine.states] == ["off", "on"] + + def test_list_events(self, microwave_cls): + sm = microwave_cls() + assert [e.name for e in sm.events] == [ + "turn_on", + "turn_off", + "door_open", + "door_close", + ] + + +class TestLCCAProperties: + def test_should_enter_initial_state(self, capsys): # noqa: C901 + class Machine(StateMachine): + class S(State.Builder): + class s1(State.Builder): + s11 = State(initial=True) + + def on_exit_s11(self): + print("leaving s11") + + def on_exit_s1(self): + print("leaving s1") + + class s2(State.Builder): + s21 = State(initial=True) + + def on_enter_s21(self): + print("entering s21") + + def on_enter_s2(self): + print("entering s2") + + def on_enter_s(self): + print("entering s") + + def on_exit_s(self): + print("leaving s") + + e = S.s1.to(S.s2.s21) + + def on_e(self): + print("executing transition") + + m = Machine() + m.send("e") + out, err = capsys.readouterr() + assert out == "leaving s11\nleaving s1\nexecuting transition\nentering s2\nentering s21\n" diff --git a/tests/test_contrib_diagram.py b/tests/test_contrib_diagram.py index b54a43d3..46cd6eef 100644 --- a/tests/test_contrib_diagram.py +++ b/tests/test_contrib_diagram.py @@ -47,7 +47,7 @@ def test_machine_dot(OrderControl): dot = graph() dot_str = dot.to_string() # or dot.to_string() - assert dot_str.startswith("digraph list {") + assert dot_str.startswith("digraph OrderControl {") class TestDiagramCmdLine: diff --git a/tests/test_nested.py b/tests/test_nested.py new file mode 100644 index 00000000..de96b979 --- /dev/null +++ b/tests/test_nested.py @@ -0,0 +1,5 @@ +def test_nested_sm(): + from tests.examples.microwave_inheritance_machine import MicroWave + + sm = MicroWave() + assert sm.current_state.id == "oven"