Skip to content

Commit 2c427cd

Browse files
committed
feat: Diagram compound and parallel states
1 parent c765228 commit 2c427cd

12 files changed

+192
-41
lines changed

conftest.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import io
12
import sys
3+
from unittest import mock
24

35
import pytest
46

@@ -31,3 +33,25 @@ def pytest_ignore_collect(collection_path, path, config):
3133

3234
if "django_project" in str(path):
3335
return True
36+
37+
38+
@pytest.fixture(autouse=True, scope="module")
39+
def mock_dot_write(request):
40+
def open_effect(
41+
filename,
42+
mode="r",
43+
*args,
44+
**kwargs,
45+
):
46+
if mode in ("r", "rt", "rb"):
47+
return open(filename, mode, *args, **kwargs)
48+
elif filename.startswith("/tmp/"):
49+
return open(filename, mode, *args, **kwargs)
50+
elif "b" in mode:
51+
return io.BytesIO()
52+
else:
53+
return io.StringIO()
54+
55+
with mock.patch("pydot.core.io.open", spec=True) as m:
56+
m.side_effect = open_effect
57+
yield m

docs/diagram.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Graphviz. For example, on Debian-based systems (such as Ubuntu), you can use the
4242
>>> dot = graph()
4343

4444
>>> dot.to_string() # doctest: +ELLIPSIS
45-
'digraph list {...
45+
'digraph OrderControl {...
4646

4747
```
4848

1.94 KB
Loading
1.22 KB
Loading
2.51 KB
Loading
438 Bytes
Loading

statemachine/contrib/diagram.py

+81-21
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,34 @@ class DotGraphMachine:
3232
def __init__(self, machine):
3333
self.machine = machine
3434

35-
def _get_graph(self):
36-
machine = self.machine
35+
def _get_graph(self, machine):
3736
return pydot.Dot(
38-
"list",
37+
machine.name,
3938
graph_type="digraph",
4039
label=machine.name,
4140
fontname=self.font_name,
4241
fontsize=self.state_font_size,
4342
rankdir=self.graph_rankdir,
43+
compound="true",
4444
)
4545

46-
def _initial_node(self):
46+
def _get_subgraph(self, state):
47+
style = ", solid"
48+
if state.parent and state.parent.parallel:
49+
style = ", dashed"
50+
subgraph = pydot.Subgraph(
51+
label=f"{state.name}",
52+
graph_name=f"cluster_{state.id}",
53+
style=f"rounded{style}",
54+
cluster="true",
55+
)
56+
return subgraph
57+
58+
def _initial_node(self, state):
4759
node = pydot.Node(
48-
"i",
49-
shape="circle",
60+
self._state_id(state),
61+
label="",
62+
shape="point",
5063
style="filled",
5164
fontsize="1pt",
5265
fixedsize="true",
@@ -56,14 +69,18 @@ def _initial_node(self):
5669
node.set_fillcolor("black")
5770
return node
5871

59-
def _initial_edge(self):
72+
def _initial_edge(self, initial_node, state):
73+
extra_params = {}
74+
if state.states:
75+
extra_params["lhead"] = f"cluster_{state.id}"
6076
return pydot.Edge(
61-
"i",
62-
self.machine.initial_state.id,
77+
initial_node.get_name(),
78+
self._state_id(state),
6379
label="",
6480
color="blue",
6581
fontname=self.font_name,
6682
fontsize=self.transition_font_size,
83+
**extra_params,
6784
)
6885

6986
def _actions_getter(self):
@@ -104,11 +121,18 @@ def _state_actions(self, state):
104121

105122
return actions
106123

124+
@staticmethod
125+
def _state_id(state):
126+
if state.states:
127+
return f"{state.id}_anchor"
128+
else:
129+
return state.id
130+
107131
def _state_as_node(self, state):
108132
actions = self._state_actions(state)
109133

110134
node = pydot.Node(
111-
state.id,
135+
self._state_id(state),
112136
label=f"{state.name}{actions}",
113137
shape="rectangle",
114138
style="rounded, filled",
@@ -127,29 +151,64 @@ def _transition_as_edge(self, transition):
127151
cond = ", ".join([str(cond) for cond in transition.cond])
128152
if cond:
129153
cond = f"\n[{cond}]"
154+
155+
extra_params = {}
156+
has_substates = transition.source.states or transition.target.states
157+
if transition.source.states:
158+
extra_params["ltail"] = f"cluster_{transition.source.id}"
159+
if transition.target.states:
160+
extra_params["lhead"] = f"cluster_{transition.target.id}"
161+
130162
return pydot.Edge(
131-
transition.source.id,
132-
transition.target.id,
163+
self._state_id(transition.source),
164+
self._state_id(transition.target),
133165
label=f"{transition.event}{cond}",
134166
color="blue",
135167
fontname=self.font_name,
136168
fontsize=self.transition_font_size,
169+
minlen=2 if has_substates else 1,
170+
**extra_params,
137171
)
138172

139173
def get_graph(self):
140-
graph = self._get_graph()
141-
graph.add_node(self._initial_node())
142-
graph.add_edge(self._initial_edge())
174+
graph = self._get_graph(self.machine)
175+
self._graph_states(self.machine, graph)
176+
return graph
143177

144-
for state in self.machine.states:
145-
graph.add_node(self._state_as_node(state))
146-
for transition in state.transitions:
178+
def _graph_states(self, state, graph):
179+
initial_node = self._initial_node(state)
180+
initial_subgraph = pydot.Subgraph(
181+
graph_name=f"{initial_node.get_name()}_initial",
182+
label="",
183+
peripheries=0,
184+
margin=0,
185+
)
186+
atomic_states_subgraph = pydot.Subgraph(
187+
graph_name=f"cluster_{initial_node.get_name()}_atomic",
188+
label="",
189+
peripheries=0,
190+
cluster="true",
191+
)
192+
initial_subgraph.add_node(initial_node)
193+
graph.add_subgraph(initial_subgraph)
194+
graph.add_subgraph(atomic_states_subgraph)
195+
196+
initial = next(s for s in state.states if s.initial)
197+
graph.add_edge(self._initial_edge(initial_node, initial))
198+
199+
for substate in state.states:
200+
if substate.states:
201+
subgraph = self._get_subgraph(substate)
202+
self._graph_states(substate, subgraph)
203+
graph.add_subgraph(subgraph)
204+
else:
205+
atomic_states_subgraph.add_node(self._state_as_node(substate))
206+
207+
for transition in substate.transitions:
147208
if transition.internal:
148209
continue
149210
graph.add_edge(self._transition_as_edge(transition))
150211

151-
return graph
152-
153212
def __call__(self):
154213
return self.get_graph()
155214

@@ -165,7 +224,8 @@ def quickchart_write_svg(sm: StateMachine, path: str):
165224
>>> from tests.examples.order_control_machine import OrderControl
166225
>>> sm = OrderControl()
167226
>>> print(sm._graph().to_string())
168-
digraph list {
227+
digraph OrderControl {
228+
compound=true;
169229
fontname=Arial;
170230
fontsize="10pt";
171231
label=OrderControl;

statemachine/factory.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
super().__init__(name, bases, attrs)
3333
registry.register(cls)
3434
cls.name = cls.__name__
35+
cls.id = cls.name.lower()
3536
cls.states: States = States()
3637
cls.states_map: Dict[Any, State] = {}
3738
"""Map of ``state.value`` to the corresponding :ref:`state`."""
@@ -46,6 +47,9 @@ def __init__(
4647

4748
if not cls.states:
4849
return
50+
51+
cls._initials_by_document_order(cls.states)
52+
4953
initials = [s for s in cls.states if s.initial]
5054
parallels = [s.id for s in cls.states if s.parallel]
5155
root_only_has_parallels = len(cls.states) == len(parallels)
@@ -59,11 +63,7 @@ def __init__(
5963
)
6064

6165
try:
62-
if root_only_has_parallels:
63-
# TODO: Temp, whe should fix initial, and current state design
64-
cls.initial_state: State = next(s for s in cls.states if s.initial)
65-
else:
66-
cls.initial_state: State = next(s for s in initials if s.initial)
66+
cls.initial_state: State = next(s for s in initials if s.initial)
6767
except StopIteration:
6868
cls.initial_state = None # Abstract SM still don't have states
6969

@@ -77,6 +77,16 @@ def __init__(
7777

7878
def __getattr__(self, attribute: str) -> Any: ...
7979

80+
def _initials_by_document_order(cls, states):
81+
has_initial = False
82+
for s in states:
83+
cls._initials_by_document_order(s.states)
84+
if s.initial:
85+
has_initial = True
86+
break
87+
if not has_initial and states:
88+
states[0]._initial = True
89+
8090
def _check(cls):
8191
has_states = bool(cls.states)
8292
has_events = bool(cls._events)
@@ -233,7 +243,7 @@ def add_state(cls, id, state: State):
233243
for event in state.transitions.unique_events:
234244
cls.add_event(event)
235245

236-
for substate in state.substates:
246+
for substate in state.states:
237247
cls.add_state(substate.id, substate)
238248

239249
def add_event(cls, event, transitions=None):

statemachine/state.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def __new__( # type: ignore [misc]
2222
if not bases:
2323
return super().__new__(cls, classname, bases, attrs) # type: ignore [return-value]
2424

25-
substates = []
25+
states = []
2626
for key, value in attrs.items():
2727
if isinstance(value, State):
2828
value._set_id(key)
29-
substates.append(value)
29+
states.append(value)
3030
if isinstance(value, TransitionList):
3131
value.add_event(key)
3232

33-
return State(name=name, substates=substates, **kwargs)
33+
return State(name=name, states=states, **kwargs)
3434

3535

3636
class State:
@@ -141,14 +141,15 @@ def __init__(
141141
initial: bool = False,
142142
final: bool = False,
143143
parallel: bool = False,
144-
substates: Any = None,
144+
states: Any = None,
145145
enter: Any = None,
146146
exit: Any = None,
147147
):
148148
self.name = name
149149
self.value = value
150150
self.parallel = parallel
151-
self.substates = substates or []
151+
self.states = states or []
152+
self.is_atomic = bool(not self.states)
152153
self._initial = initial
153154
self._final = final
154155
self._id: str = ""
@@ -161,12 +162,12 @@ def __init__(
161162
self.exit = self._specs.grouper(CallbackGroup.EXIT).add(
162163
exit, priority=CallbackPriority.INLINE
163164
)
164-
self._init_substates()
165+
self._init_states()
165166

166-
def _init_substates(self):
167-
for substate in self.substates:
168-
substate.parent = self
169-
setattr(self, substate.id, substate)
167+
def _init_states(self):
168+
for state in self.states:
169+
state.parent = self
170+
setattr(self, state.id, state)
170171

171172
def __eq__(self, other):
172173
return isinstance(other, State) and self.name == other.name and self.id == other.id
@@ -268,6 +269,7 @@ def __init__(
268269
):
269270
self._state = ref(state)
270271
self._machine = ref(machine)
272+
self._init_states()
271273

272274
@property
273275
def name(self):
@@ -313,3 +315,15 @@ def id(self) -> str:
313315
@property
314316
def is_active(self):
315317
return self._machine().current_state == self
318+
319+
@property
320+
def is_atomic(self):
321+
return self._state().is_atomic
322+
323+
@property
324+
def parent(self):
325+
return self._state().parent
326+
327+
@property
328+
def states(self):
329+
return self._state().states

statemachine/states.py

+3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def __getattr__(self, name: str):
6161
def __len__(self):
6262
return len(self._states)
6363

64+
def __getitem__(self, index):
65+
return list(self)[index]
66+
6467
def __iter__(self):
6568
return iter(self._states.values())
6669

0 commit comments

Comments
 (0)