Skip to content

Commit ecc7957

Browse files
committed
feat: Basic support for SCXML test suit
1 parent 9b55852 commit ecc7957

15 files changed

+532
-25
lines changed

conftest.py

+28
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,31 @@ def pytest_ignore_collect(collection_path, path, config):
3131

3232
if "django_project" in str(path):
3333
return True
34+
35+
36+
# @pytest.fixture(autouse=True, scope="module")
37+
# def mock_dot_write(request):
38+
# """
39+
# This fixture avoids updating files while executing tests
40+
# """
41+
42+
# def open_effect(
43+
# filename,
44+
# mode="r",
45+
# *args,
46+
# **kwargs,
47+
# ):
48+
# if mode in ("r", "rt", "rb"):
49+
# return open(filename, mode, *args, **kwargs)
50+
# elif filename.startswith("/tmp/"):
51+
# return open(filename, mode, *args, **kwargs)
52+
# elif "b" in mode:
53+
# return io.BytesIO()
54+
# else:
55+
# return io.StringIO()
56+
57+
# # using global mock instead of the fixture mocker due to the ScopeMismatch
58+
# # this fixture is module scoped and mocker is function scoped
59+
# with mock.patch("pydot.core.io.open", spec=True) as m:
60+
# m.side_effect = open_effect
61+
# yield m

pyproject.toml

+17-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dev = [
4242
"pytest-mock >=3.10.0",
4343
"pytest-benchmark >=4.0.0",
4444
"pytest-asyncio",
45+
"pydot",
4546
"django >=5.0.8; python_version >='3.10'",
4647
"pytest-django >=4.8.0; python_version >'3.8'",
4748
"Sphinx; python_version >'3.8'",
@@ -51,6 +52,7 @@ dev = [
5152
"sphinx-autobuild; python_version >'3.8'",
5253
"furo >=2024.5.6; python_version >'3.8'",
5354
"sphinx-copybutton >=0.5.2; python_version >'3.8'",
55+
"pdbr>=0.8.9; python_version >='3.8'",
5456
]
5557

5658
[build-system]
@@ -61,7 +63,21 @@ build-backend = "hatchling.build"
6163
packages = ["statemachine/"]
6264

6365
[tool.pytest.ini_options]
64-
addopts = "--ignore=docs/conf.py --ignore=docs/auto_examples/ --ignore=docs/_build/ --ignore=tests/examples/ --cov --cov-config .coveragerc --doctest-glob='*.md' --doctest-modules --doctest-continue-on-failure --benchmark-autosave --benchmark-group-by=name"
66+
addopts = [
67+
"--ignore=docs/conf.py",
68+
"--ignore=docs/auto_examples/",
69+
"--ignore=docs/_build/",
70+
"--ignore=tests/examples/",
71+
"--cov",
72+
"--cov-config",
73+
".coveragerc",
74+
"--doctest-glob=*.md",
75+
"--doctest-modules",
76+
"--doctest-continue-on-failure",
77+
"--benchmark-autosave",
78+
"--benchmark-group-by=name",
79+
"--pdbcls=pdbr:RichPdb",
80+
]
6581
doctest_optionflags = "ELLIPSIS IGNORE_EXCEPTION_DETAIL NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL"
6682
asyncio_mode = "auto"
6783
markers = ["""slow: marks tests as slow (deselect with '-m "not slow"')"""]

statemachine/callbacks.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from enum import IntEnum
66
from enum import IntFlag
77
from enum import auto
8+
from functools import partial
89
from inspect import isawaitable
910
from typing import TYPE_CHECKING
1011
from typing import Callable
@@ -89,10 +90,10 @@ def __init__(
8990
self.attr_name: str = func and func.fget and func.fget.__name__ or ""
9091
elif callable(func):
9192
self.reference = SpecReference.CALLABLE
92-
self.is_bounded = hasattr(func, "__self__")
93-
self.attr_name = (
94-
func.__name__ if not self.is_event or self.is_bounded else f"_{func.__name__}_"
95-
)
93+
is_partial = isinstance(func, partial)
94+
self.is_bounded = is_partial or hasattr(func, "__self__")
95+
name = func.func.__name__ if is_partial else func.__name__
96+
self.attr_name = name if not self.is_event or self.is_bounded else f"_{name}_"
9697
if not self.is_bounded:
9798
func.attr_name = self.attr_name
9899
func.is_event = is_event
@@ -110,7 +111,7 @@ def __repr__(self):
110111
return f"{type(self).__name__}({self.func!r}, is_convention={self.is_convention!r})"
111112

112113
def __str__(self):
113-
name = getattr(self.func, "__name__", self.func)
114+
name = self.attr_name
114115
if self.expected_value is False:
115116
name = f"!{name}"
116117
return name

statemachine/dispatcher.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _search_callable(self, spec):
166166
yield listener.build_key(spec.attr_name), partial(callable_method, func)
167167
return
168168

169-
yield f"{spec.attr_name}@None", partial(callable_method, spec.func)
169+
yield f"{spec.attr_name}-{id(spec.func)}@None", partial(callable_method, spec.func)
170170

171171
def search_name(self, name):
172172
for listener in self.items:

statemachine/io/__init__.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from typing import Dict
2+
3+
from ..factory import StateMachineMetaclass
4+
from ..state import State
5+
from ..statemachine import StateMachine
6+
from ..transition_list import TransitionList
7+
8+
9+
def create_machine_class_from_definition(
10+
name: str, definition: dict, **extra_kwargs
11+
) -> StateMachine:
12+
"""
13+
Creates a StateMachine class from a dictionary definition, using the StateMachineMetaclass.
14+
15+
Example usage with a traffic light machine:
16+
17+
>>> machine = create_machine_class_from_definition(
18+
... "TrafficLightMachine",
19+
... {
20+
... "states": {
21+
... "green": {"initial": True},
22+
... "yellow": {},
23+
... "red": {},
24+
... },
25+
... "events": {
26+
... "change": [
27+
... {"from": "green", "to": "yellow"},
28+
... {"from": "yellow", "to": "red"},
29+
... {"from": "red", "to": "green"},
30+
... ]
31+
... },
32+
... }
33+
... )
34+
35+
"""
36+
37+
states_instances = {
38+
state_id: State(**state_kwargs) for state_id, state_kwargs in definition["states"].items()
39+
}
40+
41+
events: Dict[str, TransitionList] = {}
42+
for event_name, transitions in definition["events"].items():
43+
for transition_data in transitions:
44+
source = states_instances[transition_data["from"]]
45+
target = states_instances[transition_data["to"]]
46+
47+
transition = source.to(
48+
target,
49+
event=event_name,
50+
cond=transition_data.get("cond"),
51+
unless=transition_data.get("unless"),
52+
on=transition_data.get("on"),
53+
before=transition_data.get("before"),
54+
after=transition_data.get("after"),
55+
)
56+
57+
if event_name in events:
58+
events[event_name] |= transition
59+
elif event_name is not None:
60+
events[event_name] = transition
61+
62+
attrs_mapper = {**extra_kwargs, **states_instances, **events}
63+
64+
return StateMachineMetaclass(name, (StateMachine,), attrs_mapper) # type: ignore[return-value]

statemachine/io/scxml.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""
2+
Simple SCXML parser that converts SCXML documents to state machine definitions.
3+
"""
4+
5+
import xml.etree.ElementTree as ET
6+
from functools import partial
7+
from typing import Any
8+
from typing import Dict
9+
from typing import List
10+
11+
from statemachine.statemachine import StateMachine
12+
13+
14+
def send_event(machine: StateMachine, event_to_send: str) -> None:
15+
machine.send(event_to_send)
16+
17+
18+
def assign(model, location, expr):
19+
pass
20+
21+
22+
def strip_namespaces(tree):
23+
"""Remove all namespaces from tags and attributes in place.
24+
25+
Leaves only the local names in the subtree.
26+
"""
27+
for el in tree.iter():
28+
tag = el.tag
29+
if tag and isinstance(tag, str) and tag[0] == "{":
30+
el.tag = tag.partition("}")[2]
31+
attrib = el.attrib
32+
if attrib:
33+
for name, value in list(attrib.items()):
34+
if name and isinstance(name, str) and name[0] == "{":
35+
del attrib[name]
36+
attrib[name.partition("}")[2]] = value
37+
38+
39+
def parse_scxml(scxml_content: str) -> Dict[str, Any]: # noqa: C901
40+
"""
41+
Parse SCXML content and return a dictionary definition compatible with
42+
create_machine_class_from_definition.
43+
44+
The returned dictionary has the format:
45+
{
46+
"states": {
47+
"state_id": {"initial": True},
48+
...
49+
},
50+
"events": {
51+
"event_name": [
52+
{"from": "source_state", "to": "target_state"},
53+
...
54+
]
55+
}
56+
}
57+
"""
58+
# Parse XML content
59+
root = ET.fromstring(scxml_content)
60+
strip_namespaces(root)
61+
62+
# Find the scxml element (it might be the root or a child)
63+
scxml = root if "scxml" in root.tag else root.find(".//scxml")
64+
if scxml is None:
65+
raise ValueError("No scxml element found in document")
66+
67+
# Get initial state from scxml element
68+
initial_state = scxml.get("initial")
69+
70+
# Build states dictionary
71+
states = {}
72+
events: Dict[str, List[Dict[str, str]]] = {}
73+
74+
def _parse_state(state_elem, final=False): # noqa: C901
75+
state_id = state_elem.get("id")
76+
if not state_id:
77+
raise ValueError("All states must have an id")
78+
79+
# Mark as initial if specified
80+
states[state_id] = {"initial": state_id == initial_state, "final": final}
81+
82+
# Process transitions
83+
for trans_elem in state_elem.findall("transition"):
84+
event = trans_elem.get("event") or None
85+
target = trans_elem.get("target")
86+
87+
if target:
88+
if event not in events:
89+
events[event] = []
90+
91+
if target not in states:
92+
states[target] = {}
93+
94+
events[event].append(
95+
{
96+
"from": state_id,
97+
"to": target,
98+
}
99+
)
100+
101+
for onentry_elem in state_elem.findall("onentry"):
102+
for raise_elem in onentry_elem.findall("raise"):
103+
event = raise_elem.get("event")
104+
if event:
105+
state = states[state_id]
106+
if "enter" not in state:
107+
state["enter"] = []
108+
state["enter"].append(partial(send_event, event_to_send=event))
109+
110+
# First pass: collect all states and mark initial
111+
for state_elem in scxml.findall(".//state"):
112+
_parse_state(state_elem)
113+
114+
# Second pass: collect final states
115+
for state_elem in scxml.findall(".//final"):
116+
_parse_state(state_elem, final=True)
117+
118+
# If no initial state was specified, mark the first state as initial
119+
if not initial_state and states:
120+
first_state = next(iter(states))
121+
states[first_state]["initial"] = True
122+
123+
return {
124+
"states": states,
125+
"events": events,
126+
}

statemachine/spec_parser.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import operator
23
import re
34
from typing import Callable
45

@@ -40,6 +41,23 @@ def decorated(*args, **kwargs) -> bool:
4041
return decorated
4142

4243

44+
def build_custom_operator(operator) -> Callable:
45+
def custom_comparator(left: Callable, right: Callable) -> Callable:
46+
def decorated(*args, **kwargs) -> bool:
47+
return bool(operator(left(*args, **kwargs), right(*args, **kwargs)))
48+
49+
return decorated
50+
51+
return custom_comparator
52+
53+
54+
def build_constant(constant) -> Callable:
55+
def decorated(*args, **kwargs):
56+
return constant
57+
58+
return decorated
59+
60+
4361
def custom_or(left: Callable, right: Callable) -> Callable:
4462
def decorated(*args, **kwargs) -> bool:
4563
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]
@@ -49,7 +67,7 @@ def decorated(*args, **kwargs) -> bool:
4967
return decorated
5068

5169

52-
def build_expression(node, variable_hook, operator_mapping):
70+
def build_expression(node, variable_hook, operator_mapping): # noqa: C901
5371
if isinstance(node, ast.BoolOp):
5472
# Handle `and` / `or` operations
5573
operator_fn = operator_mapping[type(node.op)]
@@ -58,13 +76,23 @@ def build_expression(node, variable_hook, operator_mapping):
5876
right_expr = build_expression(right, variable_hook, operator_mapping)
5977
left_expr = operator_fn(left_expr, right_expr)
6078
return left_expr
79+
elif isinstance(node, ast.Compare):
80+
operator_fn = operator_mapping[type(node.ops[0])]
81+
left_expr = build_expression(node.left, variable_hook, operator_mapping)
82+
for right in node.comparators:
83+
right_expr = build_expression(right, variable_hook, operator_mapping)
84+
left_expr = operator_fn(left_expr, right_expr)
85+
return left_expr
6186
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
6287
# Handle `not` operation
6388
operand_expr = build_expression(node.operand, variable_hook, operator_mapping)
6489
return operator_mapping[type(node.op)](operand_expr)
6590
elif isinstance(node, ast.Name):
6691
# Handle variables by calling the variable_hook
6792
return variable_hook(node.id)
93+
elif isinstance(node, ast.Constant):
94+
# Handle constants by returning the value
95+
return build_constant(node.value)
6896
else:
6997
raise ValueError(f"Unsupported expression structure: {node.__class__.__name__}")
7098

@@ -80,4 +108,14 @@ def parse_boolean_expr(expr, variable_hook, operator_mapping):
80108
return build_expression(tree.body, variable_hook, operator_mapping)
81109

82110

83-
operator_mapping = {ast.Or: custom_or, ast.And: custom_and, ast.Not: custom_not}
111+
operator_mapping = {
112+
ast.Or: custom_or,
113+
ast.And: custom_and,
114+
ast.Not: custom_not,
115+
ast.GtE: build_custom_operator(operator.ge),
116+
ast.Gt: build_custom_operator(operator.gt),
117+
ast.LtE: build_custom_operator(operator.le),
118+
ast.Lt: build_custom_operator(operator.lt),
119+
ast.Eq: build_custom_operator(operator.eq),
120+
ast.NotEq: build_custom_operator(operator.ne),
121+
}

0 commit comments

Comments
 (0)