@@ -75,10 +75,15 @@ def __init__(self, sm: "StateMachine"):
75
75
self ._sentinel = object ()
76
76
self .running = True
77
77
self ._processing = Lock ()
78
+ self ._cache : Dict = {} # Cache for _get_args_kwargs results
78
79
79
80
def empty (self ):
80
81
return self .external_queue .is_empty ()
81
82
83
+ def clear_cache (self ):
84
+ """Clears the cache. Should be called at the start of each processing loop."""
85
+ self ._cache .clear ()
86
+
82
87
def put (self , trigger_data : TriggerData , internal : bool = False , _delayed : bool = False ):
83
88
"""Put the trigger on the queue without blocking the caller."""
84
89
if not self .running and not self .sm .allow_event_without_transition :
@@ -310,7 +315,13 @@ def microstep(self, transitions: List[Transition], trigger_data: TriggerData):
310
315
def _get_args_kwargs (
311
316
self , transition : Transition , trigger_data : TriggerData , target : "State | None" = None
312
317
):
313
- # TODO: Ideally this method should be called only once per microstep/transition
318
+ # Generate a unique key for the cache, the cache is invalidated once per loop
319
+ cache_key = (id (transition ), id (trigger_data ), id (target ))
320
+
321
+ # Check the cache for existing results
322
+ if cache_key in self ._cache :
323
+ return self ._cache [cache_key ]
324
+
314
325
event_data = EventData (trigger_data = trigger_data , transition = transition )
315
326
if target :
316
327
event_data .state = target
@@ -321,6 +332,9 @@ def _get_args_kwargs(
321
332
result = self .sm ._callbacks .call (self .sm .prepare .key , * args , ** kwargs )
322
333
for new_kwargs in result :
323
334
kwargs .update (new_kwargs )
335
+
336
+ # Store the result in the cache
337
+ self ._cache [cache_key ] = (args , kwargs )
324
338
return args , kwargs
325
339
326
340
def _conditions_match (self , transition : Transition , trigger_data : TriggerData ):
@@ -329,7 +343,9 @@ def _conditions_match(self, transition: Transition, trigger_data: TriggerData):
329
343
self .sm ._callbacks .call (transition .validators .key , * args , ** kwargs )
330
344
return self .sm ._callbacks .all (transition .cond .key , * args , ** kwargs )
331
345
332
- def _exit_states (self , enabled_transitions : List [Transition ], trigger_data : TriggerData ):
346
+ def _exit_states (
347
+ self , enabled_transitions : List [Transition ], trigger_data : TriggerData
348
+ ) -> OrderedSet [State ]:
333
349
"""Compute and process the states to exit for the given transitions."""
334
350
states_to_exit = self ._compute_exit_set (enabled_transitions )
335
351
@@ -340,7 +356,7 @@ def _exit_states(self, enabled_transitions: List[Transition], trigger_data: Trig
340
356
ordered_states = sorted (
341
357
states_to_exit , key = lambda x : x .source and x .source .document_order or 0 , reverse = True
342
358
)
343
- result = OrderedSet ([info .source for info in ordered_states ])
359
+ result = OrderedSet ([info .source for info in ordered_states if info . source ])
344
360
logger .debug ("States to exit: %s" , result )
345
361
346
362
for info in ordered_states :
0 commit comments