Skip to content

feat: Nested states (compound / parallel) #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import sys
from unittest import mock

import pytest

Expand Down Expand Up @@ -31,3 +33,25 @@ def pytest_ignore_collect(collection_path, path, config):

if "django_project" in str(path):
return True


@pytest.fixture(autouse=True, scope="module")
def mock_dot_write(request):
def open_effect(
filename,
mode="r",
*args,
**kwargs,
):
if mode in ("r", "rt", "rb"):
return open(filename, mode, *args, **kwargs)
elif filename.startswith("/tmp/"):
return open(filename, mode, *args, **kwargs)
elif "b" in mode:
return io.BytesIO()
else:
return io.StringIO()

with mock.patch("pydot.core.io.open", spec=True) as m:
m.side_effect = open_effect
yield m
1 change: 1 addition & 0 deletions docs/_static/custom_machine.css
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

}


/* Gallery Donwload buttons */
div.sphx-glr-download a {
color: #404040 !important;
Expand Down
2 changes: 1 addition & 1 deletion docs/diagram.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Graphviz. For example, on Debian-based systems (such as Ubuntu), you can use the
>>> dot = graph()

>>> dot.to_string() # doctest: +ELLIPSIS
'digraph list {...
'digraph OrderControl {...

```

Expand Down
Binary file modified docs/images/order_control_machine_initial.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/order_control_machine_processing.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/readme_trafficlightmachine.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/test_state_machine_internal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/states.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ How to define and attach [](actions.md) to {ref}`States`.

A {ref}`StateMachine` should have one and only one `initial` {ref}`state`.

If not specified, the default initial state is the first child state in document order.

The initial {ref}`state` is entered when the machine starts and the corresponding entering
state {ref}`actions` are called if defined.
Expand Down
102 changes: 81 additions & 21 deletions statemachine/contrib/diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,34 @@ class DotGraphMachine:
def __init__(self, machine):
self.machine = machine

def _get_graph(self):
machine = self.machine
def _get_graph(self, machine):
return pydot.Dot(
"list",
machine.name,
graph_type="digraph",
label=machine.name,
fontname=self.font_name,
fontsize=self.state_font_size,
rankdir=self.graph_rankdir,
compound="true",
)

def _initial_node(self):
def _get_subgraph(self, state):
style = ", solid"
if state.parent and state.parent.parallel:
style = ", dashed"
subgraph = pydot.Subgraph(
label=f"{state.name}",
graph_name=f"cluster_{state.id}",
style=f"rounded{style}",
cluster="true",
)
return subgraph

def _initial_node(self, state):
node = pydot.Node(
"i",
shape="circle",
self._state_id(state),
label="",
shape="point",
style="filled",
fontsize="1pt",
fixedsize="true",
Expand All @@ -56,14 +69,18 @@ def _initial_node(self):
node.set_fillcolor("black")
return node

def _initial_edge(self):
def _initial_edge(self, initial_node, state):
extra_params = {}
if state.states:
extra_params["lhead"] = f"cluster_{state.id}"
return pydot.Edge(
"i",
self.machine.initial_state.id,
initial_node.get_name(),
self._state_id(state),
label="",
color="blue",
fontname=self.font_name,
fontsize=self.transition_font_size,
**extra_params,
)

def _actions_getter(self):
Expand Down Expand Up @@ -104,11 +121,18 @@ def _state_actions(self, state):

return actions

@staticmethod
def _state_id(state):
if state.states:
return f"{state.id}_anchor"
else:
return state.id

def _state_as_node(self, state):
actions = self._state_actions(state)

node = pydot.Node(
state.id,
self._state_id(state),
label=f"{state.name}{actions}",
shape="rectangle",
style="rounded, filled",
Expand All @@ -127,29 +151,64 @@ def _transition_as_edge(self, transition):
cond = ", ".join([str(cond) for cond in transition.cond])
if cond:
cond = f"\n[{cond}]"

extra_params = {}
has_substates = transition.source.states or transition.target.states
if transition.source.states:
extra_params["ltail"] = f"cluster_{transition.source.id}"
if transition.target.states:
extra_params["lhead"] = f"cluster_{transition.target.id}"

return pydot.Edge(
transition.source.id,
transition.target.id,
self._state_id(transition.source),
self._state_id(transition.target),
label=f"{transition.event}{cond}",
color="blue",
fontname=self.font_name,
fontsize=self.transition_font_size,
minlen=2 if has_substates else 1,
**extra_params,
)

def get_graph(self):
graph = self._get_graph()
graph.add_node(self._initial_node())
graph.add_edge(self._initial_edge())
graph = self._get_graph(self.machine)
self._graph_states(self.machine, graph)
return graph

for state in self.machine.states:
graph.add_node(self._state_as_node(state))
for transition in state.transitions:
def _graph_states(self, state, graph):
initial_node = self._initial_node(state)
initial_subgraph = pydot.Subgraph(
graph_name=f"{initial_node.get_name()}_initial",
label="",
peripheries=0,
margin=0,
)
atomic_states_subgraph = pydot.Subgraph(
graph_name=f"cluster_{initial_node.get_name()}_atomic",
label="",
peripheries=0,
cluster="true",
)
initial_subgraph.add_node(initial_node)
graph.add_subgraph(initial_subgraph)
graph.add_subgraph(atomic_states_subgraph)

initial = next(s for s in state.states if s.initial)
graph.add_edge(self._initial_edge(initial_node, initial))

for substate in state.states:
if substate.states:
subgraph = self._get_subgraph(substate)
self._graph_states(substate, subgraph)
graph.add_subgraph(subgraph)
else:
atomic_states_subgraph.add_node(self._state_as_node(substate))

for transition in substate.transitions:
if transition.internal:
continue
graph.add_edge(self._transition_as_edge(transition))

return graph

def __call__(self):
return self.get_graph()

Expand All @@ -165,7 +224,8 @@ def quickchart_write_svg(sm: StateMachine, path: str):
>>> from tests.examples.order_control_machine import OrderControl
>>> sm = OrderControl()
>>> print(sm._graph().to_string())
digraph list {
digraph OrderControl {
compound=true;
fontname=Arial;
fontsize="10pt";
label=OrderControl;
Expand Down
22 changes: 7 additions & 15 deletions statemachine/engines/async_.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,17 @@
from threading import Lock
from typing import TYPE_CHECKING
from weakref import proxy

from ..event_data import EventData
from ..event_data import TriggerData
from ..exceptions import InvalidDefinition
from ..exceptions import TransitionNotAllowed
from ..i18n import _
from ..transition import Transition

if TYPE_CHECKING:
from ..statemachine import StateMachine
from .base import BaseEngine


class AsyncEngine:
def __init__(self, sm: "StateMachine", rtc: bool = True):
self.sm = proxy(sm)
self._sentinel = object()
class AsyncEngine(BaseEngine):
def __init__(self, sm, rtc: bool = True):
if not rtc:
raise InvalidDefinition(_("Only RTC is supported on async engine"))
self._processing = Lock()
super().__init__(sm, rtc)

async def activate_initial_state(self):
"""
Expand Down Expand Up @@ -63,16 +55,16 @@ async def processing_loop(self):
first_result = self._sentinel
try:
# Execute the triggers in the queue in FIFO order until the queue is empty
while self.sm._external_queue:
trigger_data = self.sm._external_queue.popleft()
while self._external_queue:
trigger_data = self._external_queue.popleft()
try:
result = await self._trigger(trigger_data)
if first_result is self._sentinel:
first_result = result
except Exception:
# Whe clear the queue as we don't have an expected behavior
# and cannot keep processing
self.sm._external_queue.clear()
self._external_queue.clear()
raise
finally:
self._processing.release()
Expand Down
34 changes: 34 additions & 0 deletions statemachine/engines/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from collections import deque
from threading import Lock
from typing import TYPE_CHECKING
from weakref import proxy

from ..event_data import TriggerData

if TYPE_CHECKING:
from ..statemachine import StateMachine


class BaseEngine:
def __init__(self, sm: "StateMachine", rtc: bool = True) -> None:
self.sm = proxy(sm)
self._external_queue: deque = deque()
self._sentinel = object()
self._rtc = rtc
self._processing = Lock()
self._put_initial_activation_trigger_on_queue()

def _put_nonblocking(self, trigger_data: TriggerData):
"""Put the trigger on the queue without blocking the caller."""
self._external_queue.append(trigger_data)

def _put_initial_activation_trigger_on_queue(self):
# Activate the initial state, this only works if the outer scope is sync code.
# for async code, the user should manually call `await sm.activate_initial_state()`
# after state machine creation.
if self.sm.current_state_value is None:
trigger_data = TriggerData(
machine=self.sm,
event="__initial__",
)
self._put_nonblocking(trigger_data)
Loading
Loading