Skip to content

Commit c765228

Browse files
committed
feat: Nested states (compound / parallel)
1 parent aeae747 commit c765228

File tree

6 files changed

+283
-5
lines changed

6 files changed

+283
-5
lines changed

statemachine/factory.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,26 @@ def __init__(
4444
cls.add_inherited(bases)
4545
cls.add_from_attributes(attrs)
4646

47+
if not cls.states:
48+
return
49+
initials = [s for s in cls.states if s.initial]
50+
parallels = [s.id for s in cls.states if s.parallel]
51+
root_only_has_parallels = len(cls.states) == len(parallels)
52+
53+
if len(initials) != 1 and not root_only_has_parallels:
54+
raise InvalidDefinition(
55+
_(
56+
"There should be one and only one initial state. "
57+
"Your currently have these: {0}"
58+
).format(", ".join(s.id for s in initials))
59+
)
60+
4761
try:
48-
cls.initial_state: State = next(s for s in cls.states if s.initial)
62+
if root_only_has_parallels:
63+
# TODO: Temp, whe should fix initial, and current state design
64+
cls.initial_state: State = next(s for s in cls.states if s.initial)
65+
else:
66+
cls.initial_state: State = next(s for s in initials if s.initial)
4967
except StopIteration:
5068
cls.initial_state = None # Abstract SM still don't have states
5169

@@ -205,15 +223,19 @@ def _add_unbounded_callback(cls, attr_name, func):
205223

206224
def add_state(cls, id, state: State):
207225
state._set_id(id)
208-
cls.states.append(state)
209-
cls.states_map[state.value] = state
210-
if not hasattr(cls, id):
211-
setattr(cls, id, state)
226+
if not state.parent:
227+
cls.states.append(state)
228+
cls.states_map[state.value] = state
229+
if not hasattr(cls, id):
230+
setattr(cls, id, state)
212231

213232
# also register all events associated directly with transitions
214233
for event in state.transitions.unique_events:
215234
cls.add_event(event)
216235

236+
for substate in state.substates:
237+
cls.add_state(substate.id, substate)
238+
217239
def add_event(cls, event, transitions=None):
218240
if transitions is not None:
219241
transitions.add_event(event)

statemachine/state.py

+51
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@
1515
from .statemachine import StateMachine
1616

1717

18+
class NestedStateFactory(type):
19+
def __new__( # type: ignore [misc]
20+
cls, classname, bases, attrs, name=None, **kwargs
21+
) -> "State":
22+
if not bases:
23+
return super().__new__(cls, classname, bases, attrs) # type: ignore [return-value]
24+
25+
substates = []
26+
for key, value in attrs.items():
27+
if isinstance(value, State):
28+
value._set_id(key)
29+
substates.append(value)
30+
if isinstance(value, TransitionList):
31+
value.add_event(key)
32+
33+
return State(name=name, substates=substates, **kwargs)
34+
35+
1836
class State:
1937
"""
2038
A State in a :ref:`StateMachine` describes a particular behavior of the machine.
@@ -94,20 +112,47 @@ class State:
94112
95113
"""
96114

115+
class Builder(metaclass=NestedStateFactory):
116+
# Mimic the :ref:`State` public API to help linters discover the result of the Builder
117+
# class.
118+
119+
@classmethod
120+
def to(cls, *args: "State", **kwargs) -> "TransitionList": # pragma: no cover
121+
"""Create transitions to the given target states.
122+
123+
.. note: This method is only a type hint for mypy.
124+
The actual implementation belongs to the :ref:`State` class.
125+
"""
126+
return TransitionList()
127+
128+
@classmethod
129+
def from_(cls, *args: "State", **kwargs) -> "TransitionList": # pragma: no cover
130+
"""Create transitions from the given target states (reversed).
131+
132+
.. note: This method is only a type hint for mypy.
133+
The actual implementation belongs to the :ref:`State` class.
134+
"""
135+
return TransitionList()
136+
97137
def __init__(
98138
self,
99139
name: str = "",
100140
value: Any = None,
101141
initial: bool = False,
102142
final: bool = False,
143+
parallel: bool = False,
144+
substates: Any = None,
103145
enter: Any = None,
104146
exit: Any = None,
105147
):
106148
self.name = name
107149
self.value = value
150+
self.parallel = parallel
151+
self.substates = substates or []
108152
self._initial = initial
109153
self._final = final
110154
self._id: str = ""
155+
self.parent: "State" = None
111156
self.transitions = TransitionList()
112157
self._specs = CallbackSpecList()
113158
self.enter = self._specs.grouper(CallbackGroup.ENTER).add(
@@ -116,6 +161,12 @@ def __init__(
116161
self.exit = self._specs.grouper(CallbackGroup.EXIT).add(
117162
exit, priority=CallbackPriority.INLINE
118163
)
164+
self._init_substates()
165+
166+
def _init_substates(self):
167+
for substate in self.substates:
168+
substate.parent = self
169+
setattr(self, substate.id, substate)
119170

120171
def __eq__(self, other):
121172
return isinstance(other, State) and self.name == other.name and self.id == other.id
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""
2+
Microwave machine
3+
=================
4+
5+
Example that exercises the Compound and Parallel states.
6+
7+
Compound
8+
--------
9+
10+
If there are more than one substates, one of them is usually designated as the initial state of
11+
that compound state.
12+
13+
When a compound state is active, its substates behave as though they were an active state machine:
14+
Exactly one child state must also be active. This means that:
15+
16+
When a compound state is entered, it must also enter exactly one of its substates, usually its
17+
initial state.
18+
When an event happens, the substates have priority when it comes to selecting which transition to
19+
follow. If a substate happens to handles an event, the event is consumed, it isn’t passed to the
20+
parent compound state.
21+
When a substate transitions to another substate, both “inside” the compound state, the compound
22+
state does not exit or enter; it remains active.
23+
When a compound state exits, its substate is simultaneously exited too. (Technically, the substate
24+
exits first, then its parent.)
25+
Compound states may be nested, or include parallel states.
26+
27+
The opposite of a compound state is an atomic state, which is a state with no substates.
28+
29+
A compound state is allowed to define transitions to its child states. Normally, when a transition
30+
leads from a state, it causes that state to be exited. For transitions from a compound state to
31+
one of its descendants, it is possible to define a transition that avoids exiting and entering
32+
the compound state itself, such transitions are called local transitions.
33+
34+
35+
"""
36+
from statemachine import State
37+
from statemachine import StateMachine
38+
39+
40+
class MicroWave(StateMachine):
41+
class oven(State.Builder, name="Microwave oven", parallel=True):
42+
class engine(State.Builder):
43+
off = State("Off", initial=True)
44+
45+
class on(State.Builder):
46+
idle = State("Idle", initial=True)
47+
cooking = State("Cooking")
48+
49+
idle.to(cooking, cond="closed.is_active")
50+
cooking.to(idle, cond="open.is_active")
51+
cooking.to.itself(internal=True, on="increment_timer")
52+
53+
turn_off = on.to(off)
54+
turn_on = off.to(on)
55+
on.to(off, cond="cook_time_is_over") # eventless transition
56+
57+
class door(State.Builder):
58+
closed = State(initial=True)
59+
open = State()
60+
61+
door_open = closed.to(open)
62+
door_close = open.to(closed)
63+
64+
def __init__(self):
65+
self.cook_time = 5
66+
self.door_closed = True
67+
self.timer = 0
68+
super().__init__()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""
2+
Nested Traffic light machine
3+
----------------------------
4+
5+
Demonstrates the concept of nested compound states.
6+
7+
From this example on XState: https://xstate.js.org/docs/guides/hierarchical.html#api
8+
9+
"""
10+
import time
11+
12+
from statemachine import State
13+
from statemachine import StateMachine
14+
15+
16+
class NestedTrafficLightMachine(StateMachine):
17+
"A traffic light machine"
18+
green = State(initial=True, enter="reset_elapsed")
19+
yellow = State(enter="reset_elapsed")
20+
21+
class red(State.Builder, enter="reset_elapsed"):
22+
"Pedestrian states"
23+
walk = State(initial=True)
24+
wait = State()
25+
stop = State()
26+
blinking = State()
27+
28+
ped_countdown = walk.to(wait) | wait.to(stop)
29+
30+
timer = green.to(yellow) | yellow.to(red) | red.to(green)
31+
power_outage = red.blinking.from_()
32+
power_restored = red.from_()
33+
34+
def __init__(self, seconds_to_turn_state=5, seconds_running=20):
35+
self.seconds_to_turn_state = seconds_to_turn_state
36+
self.seconds_running = seconds_running
37+
super().__init__(allow_event_without_transition=True)
38+
39+
def on_timer(self, event: str, source: State, target: State):
40+
print(f".. Running {event} from {source.id} to {target.id}")
41+
42+
def reset_elapsed(self, event: str, time: int = 0):
43+
print(f"entering reset_elapsed from {event} with {time}")
44+
self.last_turn = time
45+
46+
@timer.cond
47+
def time_is_over(self, time):
48+
return time - self.last_turn > self.seconds_to_turn_state
49+
50+
def run_forever(self):
51+
self.running = True
52+
start_time = time.time()
53+
while self.running:
54+
print("tick!")
55+
time.sleep(1)
56+
curr_time = time.time()
57+
self.send("timer", time=curr_time)
58+
59+
if curr_time - start_time > self.seconds_running:
60+
self.running = False
61+
62+
63+
sm = NestedTrafficLightMachine()
64+
sm.send("anything")

tests/test_compound.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
3+
from statemachine import State
4+
5+
6+
@pytest.fixture()
7+
def microwave_cls():
8+
from tests.examples.microwave_inheritance_machine import MicroWave
9+
10+
return MicroWave
11+
12+
13+
def assert_state(s, name, initial=False, final=False, parallel=False, substates=None):
14+
if substates is None:
15+
substates = []
16+
17+
assert isinstance(s, State)
18+
assert s.name == name
19+
assert s.initial is initial
20+
assert s.final is final
21+
assert s.parallel is parallel
22+
assert isinstance(s, State)
23+
assert set(s.substates) == set(substates)
24+
25+
26+
class TestNestedSyntax:
27+
def test_capture_constructor_arguments(self, microwave_cls):
28+
sm = microwave_cls()
29+
30+
assert_state(
31+
sm.oven,
32+
"Microwave oven",
33+
parallel=True,
34+
substates=[sm.oven.engine, sm.oven.door],
35+
)
36+
assert_state(
37+
sm.oven.engine,
38+
"Engine",
39+
initial=False,
40+
substates=[sm.oven.engine.on, sm.oven.engine.off],
41+
)
42+
assert_state(sm.oven.engine.off, "Off", initial=True)
43+
assert_state(
44+
sm.oven.engine.on,
45+
"On",
46+
substates=[sm.oven.engine.on.idle, sm.oven.engine.on.cooking],
47+
)
48+
assert_state(
49+
sm.oven.door,
50+
"Door",
51+
initial=False,
52+
substates=[sm.oven.door.closed, sm.oven.door.open],
53+
)
54+
assert_state(sm.oven.door.closed, "Closed", initial=True)
55+
assert_state(sm.oven.door.open, "Open")
56+
57+
def test_list_children_states(self, microwave_cls):
58+
sm = microwave_cls()
59+
assert [s.id for s in sm.oven.engine.substates] == ["off", "on"]
60+
61+
def test_list_events(self, microwave_cls):
62+
sm = microwave_cls()
63+
assert [e.name for e in sm.events] == [
64+
"turn_on",
65+
"turn_off",
66+
"door_open",
67+
"door_close",
68+
]

tests/test_nested.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def test_nested_sm():
2+
from tests.examples.microwave_inheritance_machine import MicroWave
3+
4+
sm = MicroWave()
5+
assert sm.current_state.id == "oven"

0 commit comments

Comments
 (0)