From 6f51ba19d08bcab372b42016a8e6182e36669021 Mon Sep 17 00:00:00 2001 From: Adam Bella Date: Mon, 30 Aug 2021 13:58:05 +0200 Subject: [PATCH 1/4] refactor lua reference tracking --- lupa/_lupa.pyx | 62 +++++++++++++++++++++++++++++++++++++++++++++- lupa/tests/test.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/lupa/_lupa.pyx b/lupa/_lupa.pyx index 9af27175..f2f58c9f 100644 --- a/lupa/_lupa.pyx +++ b/lupa/_lupa.pyx @@ -242,6 +242,7 @@ cdef class LuaRuntime: cdef object _attribute_getter cdef object _attribute_setter cdef bint _unpack_returned_tuples + cdef StatesReferenceTracker _ref_tracker def __cinit__(self, encoding='UTF-8', source_encoding=None, attribute_filter=None, attribute_handlers=None, @@ -255,6 +256,8 @@ cdef class LuaRuntime: self._pyrefs_in_lua = {} self._encoding = _asciiOrNone(encoding) self._source_encoding = _asciiOrNone(source_encoding) or self._encoding or b'UTF-8' + self._ref_tracker = StatesReferenceTracker() + self._ref_tracker.track_state(L) if attribute_filter is not None and not callable(attribute_filter): raise ValueError("attribute_filter must be callable") self._attribute_filter = attribute_filter @@ -284,6 +287,7 @@ cdef class LuaRuntime: def __dealloc__(self): if self._state is not NULL: + self._ref_tracker.unref_state_references(self._state) lua.lua_close(self._state) self._state = NULL @@ -724,7 +728,7 @@ cdef class _LuaObject: cdef lua_State* L = self._state if L is not NULL and self._ref != lua.LUA_NOREF: locked = lock_runtime(self._runtime) - lua.luaL_unref(L, lua.LUA_REGISTRYINDEX, self._ref) + self._runtime._ref_tracker.unref_reference(L, self._ref) self._ref = lua.LUA_NOREF runtime = self._runtime self._runtime = None @@ -875,6 +879,7 @@ cdef void init_lua_object(_LuaObject obj, LuaRuntime runtime, lua_State* L, int obj._state = L lua.lua_pushvalue(L, n) obj._ref = lua.luaL_ref(L, lua.LUA_REGISTRYINDEX) + runtime._ref_tracker.track_reference(L, obj._ref) cdef object lua_object_repr(lua_State* L, bytes encoding): cdef bytes py_bytes @@ -1100,6 +1105,7 @@ cdef _LuaThread new_lua_thread(LuaRuntime runtime, lua_State* L, int n): cdef _LuaThread obj = _LuaThread.__new__(_LuaThread) init_lua_object(obj, runtime, L, n) obj._co_state = lua.lua_tothread(L, n) + runtime._ref_tracker.track_state(obj._co_state) return obj @@ -1127,6 +1133,7 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): cdef lua_State* L = thread._state cdef int status, i, nargs = 0, nres = 0 assert thread._runtime is not None + cdef bint done = False lock_runtime(thread._runtime) old_top = lua.lua_gettop(L) try: @@ -1144,8 +1151,10 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): # terminated if nres == 0: # no values left to return + done = True raise StopIteration else: + done = True raise_lua_error(thread._runtime, co, status) # Move yielded values to the main state before unpacking. @@ -1156,6 +1165,9 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): finally: # FIXME: check that coroutine state is OK in case of errors? lua.lua_settop(L, old_top) + if done: + # unref all state references + thread._runtime._ref_tracker.unref_state_references(co) unlock_runtime(thread._runtime) @@ -1746,6 +1758,54 @@ cdef int py_object_gc(lua_State* L) nogil: return lua.lua_error(L) # never returns! return 0 +# ref-counting support for lua objects + +@cython.final +@cython.internal +@cython.no_gc_clear +cdef class StatesReferenceTracker: + """Track all lua state objects and their references.""" + + cdef dict _states_references # Dict[LuaState, Set[int]] + + def __cinit__(self): + self._states_references = {} + + cdef int track_state(self, lua_State *L) except -1: + cdef state = L + assert state not in self._states_references + self._states_references[state] = set() + return 0 + + cdef int unref_state_references(self, lua_State *L) except -1: + cdef set references = self._states_references.pop(L, None) + if not references: + return 0 + + # unref all tracked references + for ref in references: + lua.luaL_unref(L, lua.LUA_REGISTRYINDEX, ref) + return 0 + + cdef int track_reference(self, lua_State *L, int ref) except -1: + cdef set references = self._states_references.get(L, None) + if references is None: + return 0 + + references.add(ref) + return 0 + + cdef int unref_reference(self, lua_State *L, int ref) except -1: + cdef set references = self._states_references.get(L, None) + if references is None: + return 0 + + if ref in references: + references.remove(ref) + lua.luaL_unref(L, lua.LUA_REGISTRYINDEX, ref) + + return 0 + # calling Python objects cdef bint call_python(LuaRuntime runtime, lua_State *L, py_object* py_obj) except -1: diff --git a/lupa/tests/test.py b/lupa/tests/test.py index d8706919..6472ed06 100644 --- a/lupa/tests/test.py +++ b/lupa/tests/test.py @@ -2972,6 +2972,51 @@ def test_bad_tostring(self): def test_tostring_err(self): self.assertRaises(lupa.LuaError, str, self.lua.eval('setmetatable({}, {__tostring = function() error() end})')) +class TestSigSegScenarios(unittest.TestCase): + class PendingRequest: + + def __init__(self, callback): + self.__callback = callback + + def make_request(callback): + return TestSigSegScenarios.PendingRequest(callback) + + def test_callback_passing(self): + lua = lupa.LuaRuntime() + lua.globals().make_request = TestSigSegScenarios.make_request + run = lua.eval(""" + function() + make_request(function() end) + end + """) + + for i in range(10000): + thread = run.coroutine() + try: + thread.send(None) + except StopIteration: + pass + + # assert no segmentation fault + + def test_callback_passing_with_exception(self): + lua = lupa.LuaRuntime() + lua.globals().make_request = TestSigSegScenarios.make_request + run = lua.eval(""" + function() + make_request(function() end) + error('test error') + end + """) + + for i in range(10000): + thread = run.coroutine() + try: + thread.send(None) + except Exception: + pass + + # assert no segmentation fault if __name__ == '__main__': def print_version(): From edbcd62168229147fa583ab1d8d24fb93eaa9073 Mon Sep 17 00:00:00 2001 From: Adam Bella Date: Tue, 31 Aug 2021 21:19:09 +0200 Subject: [PATCH 2/4] fix tests for python 2.7 --- lupa/tests/test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lupa/tests/test.py b/lupa/tests/test.py index 6472ed06..7f720135 100644 --- a/lupa/tests/test.py +++ b/lupa/tests/test.py @@ -2973,17 +2973,17 @@ def test_tostring_err(self): self.assertRaises(lupa.LuaError, str, self.lua.eval('setmetatable({}, {__tostring = function() error() end})')) class TestSigSegScenarios(unittest.TestCase): - class PendingRequest: + class PendingRequest(object): def __init__(self, callback): self.__callback = callback - def make_request(callback): + def make_request(self, callback): return TestSigSegScenarios.PendingRequest(callback) def test_callback_passing(self): lua = lupa.LuaRuntime() - lua.globals().make_request = TestSigSegScenarios.make_request + lua.globals().make_request = self.make_request run = lua.eval(""" function() make_request(function() end) @@ -3001,7 +3001,7 @@ def test_callback_passing(self): def test_callback_passing_with_exception(self): lua = lupa.LuaRuntime() - lua.globals().make_request = TestSigSegScenarios.make_request + lua.globals().make_request = self.make_request run = lua.eval(""" function() make_request(function() end) From 422af1992d31048fa61d8143665b9f6136ce5726 Mon Sep 17 00:00:00 2001 From: Adam Bella Date: Fri, 3 Sep 2021 18:08:46 +0200 Subject: [PATCH 3/4] fix unrefereing all objects created in coroutine Unreferencing all objects created in coroutine is not good solution because such objects can be used outside of coroutine scope. This fix refactor all reference tracking to tracking objects and when coroutine is done, set state of all objects that was created in coroutine to parent lua state. This way objects created in coroutine can be used outside of coroutine scope and it does not cause segmentation fault. --- lupa/_lupa.pyx | 92 ++++++++++++++++++++++++++++------------------ lupa/tests/test.py | 12 +++--- 2 files changed, 61 insertions(+), 43 deletions(-) diff --git a/lupa/_lupa.pyx b/lupa/_lupa.pyx index f2f58c9f..7f12f79c 100644 --- a/lupa/_lupa.pyx +++ b/lupa/_lupa.pyx @@ -21,6 +21,8 @@ from cpython.method cimport ( PyMethod_Check, PyMethod_GET_SELF, PyMethod_GET_FUNCTION) from cpython.bytes cimport PyBytes_FromFormat +import weakref + #from libc.stdint cimport uintptr_t cdef extern from *: """ @@ -242,7 +244,7 @@ cdef class LuaRuntime: cdef object _attribute_getter cdef object _attribute_setter cdef bint _unpack_returned_tuples - cdef StatesReferenceTracker _ref_tracker + cdef CostatesObjTracker _obj_tracker def __cinit__(self, encoding='UTF-8', source_encoding=None, attribute_filter=None, attribute_handlers=None, @@ -256,8 +258,7 @@ cdef class LuaRuntime: self._pyrefs_in_lua = {} self._encoding = _asciiOrNone(encoding) self._source_encoding = _asciiOrNone(source_encoding) or self._encoding or b'UTF-8' - self._ref_tracker = StatesReferenceTracker() - self._ref_tracker.track_state(L) + self._obj_tracker = CostatesObjTracker() if attribute_filter is not None and not callable(attribute_filter): raise ValueError("attribute_filter must be callable") self._attribute_filter = attribute_filter @@ -287,7 +288,6 @@ cdef class LuaRuntime: def __dealloc__(self): if self._state is not NULL: - self._ref_tracker.unref_state_references(self._state) lua.lua_close(self._state) self._state = NULL @@ -715,6 +715,7 @@ cdef class _LuaObject: cdef LuaRuntime _runtime cdef lua_State* _state cdef int _ref + cdef object __weakref__ def __cinit__(self): self._ref = lua.LUA_NOREF @@ -728,7 +729,8 @@ cdef class _LuaObject: cdef lua_State* L = self._state if L is not NULL and self._ref != lua.LUA_NOREF: locked = lock_runtime(self._runtime) - self._runtime._ref_tracker.unref_reference(L, self._ref) + if self._runtime._obj_tracker.untrack_obj(self): + lua.luaL_unref(self._state, lua.LUA_REGISTRYINDEX, self._ref) self._ref = lua.LUA_NOREF runtime = self._runtime self._runtime = None @@ -879,7 +881,7 @@ cdef void init_lua_object(_LuaObject obj, LuaRuntime runtime, lua_State* L, int obj._state = L lua.lua_pushvalue(L, n) obj._ref = lua.luaL_ref(L, lua.LUA_REGISTRYINDEX) - runtime._ref_tracker.track_reference(L, obj._ref) + runtime._obj_tracker.track_obj(obj) cdef object lua_object_repr(lua_State* L, bytes encoding): cdef bytes py_bytes @@ -1063,6 +1065,16 @@ cdef class _LuaThread(_LuaObject): """ cdef lua_State* _co_state cdef tuple _arguments + + def __dealloc__(self): + cdef LuaRuntime runtime = self._runtime + # set states of all unreferenced coroutine objects to parent state + # because coroutine state no longer exists + costate_objs = runtime._obj_tracker.untrack_costate_objects(self._co_state) + if costate_objs: + for obj in costate_objs: + (<_LuaObject>obj)._state = self._state + def __iter__(self): return self @@ -1105,7 +1117,7 @@ cdef _LuaThread new_lua_thread(LuaRuntime runtime, lua_State* L, int n): cdef _LuaThread obj = _LuaThread.__new__(_LuaThread) init_lua_object(obj, runtime, L, n) obj._co_state = lua.lua_tothread(L, n) - runtime._ref_tracker.track_state(obj._co_state) + runtime._obj_tracker.track_state(obj._co_state) return obj @@ -1166,8 +1178,12 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): # FIXME: check that coroutine state is OK in case of errors? lua.lua_settop(L, old_top) if done: - # unref all state references - thread._runtime._ref_tracker.unref_state_references(co) + # set states of all unreferenced coroutine objects to parent state + # because coroutine state no longer exists + costate_objs = thread._runtime._obj_tracker.untrack_costate_objects(co) + if costate_objs: + for obj in costate_objs: + (<_LuaObject>obj)._state = L unlock_runtime(thread._runtime) @@ -1763,46 +1779,50 @@ cdef int py_object_gc(lua_State* L) nogil: @cython.final @cython.internal @cython.no_gc_clear -cdef class StatesReferenceTracker: - """Track all lua state objects and their references.""" +cdef class CostatesObjTracker: + """Track all coroutine state objects. - cdef dict _states_references # Dict[LuaState, Set[int]] + Uses weak object references so that gc can make cleanup.""" + + cdef dict _costates_objs # Dict[LuaState, weakref.WeakSet[_LuaObject]] def __cinit__(self): - self._states_references = {} + self._costates_objs = {} cdef int track_state(self, lua_State *L) except -1: + """Add new coroutine state to track its object""" cdef state = L - assert state not in self._states_references - self._states_references[state] = set() + assert state not in self._costates_objs + self._costates_objs[state] = weakref.WeakSet() return 0 - cdef int unref_state_references(self, lua_State *L) except -1: - cdef set references = self._states_references.pop(L, None) - if not references: - return 0 - - # unref all tracked references - for ref in references: - lua.luaL_unref(L, lua.LUA_REGISTRYINDEX, ref) - return 0 + cdef object untrack_costate_objects(self, lua_State *L): + """Remove coroutine state and return its objects that has not been untracked yet.""" + cdef state = L + return self._costates_objs.pop(state, None) - cdef int track_reference(self, lua_State *L, int ref) except -1: - cdef set references = self._states_references.get(L, None) - if references is None: + cdef int track_obj(self, _LuaObject obj) except -1: + """Add object reference if it belongs to tracked coroutine state.""" + cdef state = obj._state + cdef state_objs = self._costates_objs.get(state, None) + if state_objs is None: return 0 - references.add(ref) - return 0 + state_objs.add(obj) + return 1 - cdef int unref_reference(self, lua_State *L, int ref) except -1: - cdef set references = self._states_references.get(L, None) - if references is None: - return 0 + cdef int untrack_obj(self, _LuaObject obj) except -1: + """Remove object reference if it belong to tracked coroutine state. + Returns 1 if object needs to be unreferenced in lua. + """ + cdef state = obj._state + cdef state_objs = self._costates_objs.get(state, None) + if state_objs is None: + return 1 - if ref in references: - references.remove(ref) - lua.luaL_unref(L, lua.LUA_REGISTRYINDEX, ref) + if obj in state_objs: + state_objs.remove(obj) + return 1 return 0 diff --git a/lupa/tests/test.py b/lupa/tests/test.py index 7f720135..bace4091 100644 --- a/lupa/tests/test.py +++ b/lupa/tests/test.py @@ -2972,7 +2972,7 @@ def test_bad_tostring(self): def test_tostring_err(self): self.assertRaises(lupa.LuaError, str, self.lua.eval('setmetatable({}, {__tostring = function() error() end})')) -class TestSigSegScenarios(unittest.TestCase): +class TestSigSegScenarios(SetupLuaRuntimeMixin, unittest.TestCase): class PendingRequest(object): def __init__(self, callback): @@ -2982,9 +2982,8 @@ def make_request(self, callback): return TestSigSegScenarios.PendingRequest(callback) def test_callback_passing(self): - lua = lupa.LuaRuntime() - lua.globals().make_request = self.make_request - run = lua.eval(""" + self.lua.globals().make_request = self.make_request + run = self.lua.eval(""" function() make_request(function() end) end @@ -3000,9 +2999,8 @@ def test_callback_passing(self): # assert no segmentation fault def test_callback_passing_with_exception(self): - lua = lupa.LuaRuntime() - lua.globals().make_request = self.make_request - run = lua.eval(""" + self.lua.globals().make_request = self.make_request + run = self.lua.eval(""" function() make_request(function() end) error('test error') From 17dd7eac040acab187760d6bfd24c2831f1a9cd2 Mon Sep 17 00:00:00 2001 From: Adam Bella Date: Tue, 7 Sep 2021 10:26:52 +0200 Subject: [PATCH 4/4] segmentation fault fix - simplified version This fix use runtime state for unreferencing LuaObject because using coroutine state of already closed coroutine cause segmentation fault. --- lupa/_lupa.pyx | 84 ++------------------------------------------------ 1 file changed, 2 insertions(+), 82 deletions(-) diff --git a/lupa/_lupa.pyx b/lupa/_lupa.pyx index 7f12f79c..4a4841f7 100644 --- a/lupa/_lupa.pyx +++ b/lupa/_lupa.pyx @@ -21,8 +21,6 @@ from cpython.method cimport ( PyMethod_Check, PyMethod_GET_SELF, PyMethod_GET_FUNCTION) from cpython.bytes cimport PyBytes_FromFormat -import weakref - #from libc.stdint cimport uintptr_t cdef extern from *: """ @@ -244,7 +242,6 @@ cdef class LuaRuntime: cdef object _attribute_getter cdef object _attribute_setter cdef bint _unpack_returned_tuples - cdef CostatesObjTracker _obj_tracker def __cinit__(self, encoding='UTF-8', source_encoding=None, attribute_filter=None, attribute_handlers=None, @@ -258,7 +255,6 @@ cdef class LuaRuntime: self._pyrefs_in_lua = {} self._encoding = _asciiOrNone(encoding) self._source_encoding = _asciiOrNone(source_encoding) or self._encoding or b'UTF-8' - self._obj_tracker = CostatesObjTracker() if attribute_filter is not None and not callable(attribute_filter): raise ValueError("attribute_filter must be callable") self._attribute_filter = attribute_filter @@ -715,7 +711,6 @@ cdef class _LuaObject: cdef LuaRuntime _runtime cdef lua_State* _state cdef int _ref - cdef object __weakref__ def __cinit__(self): self._ref = lua.LUA_NOREF @@ -726,11 +721,10 @@ cdef class _LuaObject: def __dealloc__(self): if self._runtime is None: return - cdef lua_State* L = self._state + cdef lua_State* L = self._runtime._state if L is not NULL and self._ref != lua.LUA_NOREF: locked = lock_runtime(self._runtime) - if self._runtime._obj_tracker.untrack_obj(self): - lua.luaL_unref(self._state, lua.LUA_REGISTRYINDEX, self._ref) + lua.luaL_unref(L, lua.LUA_REGISTRYINDEX, self._ref) self._ref = lua.LUA_NOREF runtime = self._runtime self._runtime = None @@ -881,7 +875,6 @@ cdef void init_lua_object(_LuaObject obj, LuaRuntime runtime, lua_State* L, int obj._state = L lua.lua_pushvalue(L, n) obj._ref = lua.luaL_ref(L, lua.LUA_REGISTRYINDEX) - runtime._obj_tracker.track_obj(obj) cdef object lua_object_repr(lua_State* L, bytes encoding): cdef bytes py_bytes @@ -1065,16 +1058,6 @@ cdef class _LuaThread(_LuaObject): """ cdef lua_State* _co_state cdef tuple _arguments - - def __dealloc__(self): - cdef LuaRuntime runtime = self._runtime - # set states of all unreferenced coroutine objects to parent state - # because coroutine state no longer exists - costate_objs = runtime._obj_tracker.untrack_costate_objects(self._co_state) - if costate_objs: - for obj in costate_objs: - (<_LuaObject>obj)._state = self._state - def __iter__(self): return self @@ -1117,7 +1100,6 @@ cdef _LuaThread new_lua_thread(LuaRuntime runtime, lua_State* L, int n): cdef _LuaThread obj = _LuaThread.__new__(_LuaThread) init_lua_object(obj, runtime, L, n) obj._co_state = lua.lua_tothread(L, n) - runtime._obj_tracker.track_state(obj._co_state) return obj @@ -1145,7 +1127,6 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): cdef lua_State* L = thread._state cdef int status, i, nargs = 0, nres = 0 assert thread._runtime is not None - cdef bint done = False lock_runtime(thread._runtime) old_top = lua.lua_gettop(L) try: @@ -1163,10 +1144,8 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): # terminated if nres == 0: # no values left to return - done = True raise StopIteration else: - done = True raise_lua_error(thread._runtime, co, status) # Move yielded values to the main state before unpacking. @@ -1177,13 +1156,6 @@ cdef object resume_lua_thread(_LuaThread thread, tuple args): finally: # FIXME: check that coroutine state is OK in case of errors? lua.lua_settop(L, old_top) - if done: - # set states of all unreferenced coroutine objects to parent state - # because coroutine state no longer exists - costate_objs = thread._runtime._obj_tracker.untrack_costate_objects(co) - if costate_objs: - for obj in costate_objs: - (<_LuaObject>obj)._state = L unlock_runtime(thread._runtime) @@ -1774,58 +1746,6 @@ cdef int py_object_gc(lua_State* L) nogil: return lua.lua_error(L) # never returns! return 0 -# ref-counting support for lua objects - -@cython.final -@cython.internal -@cython.no_gc_clear -cdef class CostatesObjTracker: - """Track all coroutine state objects. - - Uses weak object references so that gc can make cleanup.""" - - cdef dict _costates_objs # Dict[LuaState, weakref.WeakSet[_LuaObject]] - - def __cinit__(self): - self._costates_objs = {} - - cdef int track_state(self, lua_State *L) except -1: - """Add new coroutine state to track its object""" - cdef state = L - assert state not in self._costates_objs - self._costates_objs[state] = weakref.WeakSet() - return 0 - - cdef object untrack_costate_objects(self, lua_State *L): - """Remove coroutine state and return its objects that has not been untracked yet.""" - cdef state = L - return self._costates_objs.pop(state, None) - - cdef int track_obj(self, _LuaObject obj) except -1: - """Add object reference if it belongs to tracked coroutine state.""" - cdef state = obj._state - cdef state_objs = self._costates_objs.get(state, None) - if state_objs is None: - return 0 - - state_objs.add(obj) - return 1 - - cdef int untrack_obj(self, _LuaObject obj) except -1: - """Remove object reference if it belong to tracked coroutine state. - Returns 1 if object needs to be unreferenced in lua. - """ - cdef state = obj._state - cdef state_objs = self._costates_objs.get(state, None) - if state_objs is None: - return 1 - - if obj in state_objs: - state_objs.remove(obj) - return 1 - - return 0 - # calling Python objects cdef bint call_python(LuaRuntime runtime, lua_State *L, py_object* py_obj) except -1: