Skip to content

Commit 05e6e2d

Browse files
committed
override: Fix wrong caching of the overrides
There was a problem when the python type, which was stored in override cache for C++ functions, was destroyed and the record wasn't removed from the override cache. Therefor, dangling pointer was stored there. Then when the memory was reused and new type was allocated at the given address and the method with the same name (as previously stored in the cache) was actually overridden in python, it would wrongly find it in the override cache for C++ functions and therefor override from python wouldn't be called. The fix is to erase the type from the override cache when the type is destroyed.
1 parent 58c7f07 commit 05e6e2d

File tree

6 files changed

+123
-1
lines changed

6 files changed

+123
-1
lines changed

include/pybind11/pybind11.h

+10
Original file line numberDiff line numberDiff line change
@@ -2093,6 +2093,16 @@ inline std::pair<decltype(internals::registered_types_py)::iterator, bool> all_t
20932093
// gets destroyed:
20942094
weakref((PyObject *) type, cpp_function([type](handle wr) {
20952095
get_internals().registered_types_py.erase(type);
2096+
2097+
// Actually just `std::erase_if`, but that's only available in C++20
2098+
auto &cache = get_internals().inactive_override_cache;
2099+
for (auto it = cache.begin(), last = cache.end(); it != last; ) {
2100+
if (it->first == reinterpret_cast<PyObject *>(type))
2101+
it = cache.erase(it);
2102+
else
2103+
++it;
2104+
}
2105+
20962106
wr.dec_ref();
20972107
})).release();
20982108
}

tests/test_class_sh_inheritance.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ struct drvd2 : base1, base2 {
4949
int id() const override { return 3 * base1::base_id + 4 * base2::base_id; }
5050
};
5151

52+
class test_derived {
53+
54+
public:
55+
virtual int func() { return 0; }
56+
57+
test_derived() = default;
58+
~test_derived() = default;
59+
// Non-copyable
60+
test_derived &operator=(test_derived const &Right) = delete;
61+
test_derived(test_derived const &Copy) = delete;
62+
};
63+
64+
class py_test_derived : public test_derived {
65+
virtual int func() override { PYBIND11_OVERRIDE(int, test_derived, func); }
66+
};
67+
68+
inline int test_override_cache(std::shared_ptr < test_derived> instance) { return instance->func(); }
69+
5270
// clang-format off
5371
inline drvd2 *rtrn_mptr_drvd2() { return new drvd2; }
5472
inline base1 *rtrn_mptr_drvd2_up_cast1() { return new drvd2; }
@@ -69,6 +87,8 @@ PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::base1)
6987
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::base2)
7088
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::drvd2)
7189

90+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::test_derived)
91+
7292
namespace pybind11_tests {
7393
namespace class_sh_inheritance {
7494

@@ -99,6 +119,12 @@ TEST_SUBMODULE(class_sh_inheritance, m) {
99119
m.def("pass_cptr_base1", pass_cptr_base1);
100120
m.def("pass_cptr_base2", pass_cptr_base2);
101121
m.def("pass_cptr_drvd2", pass_cptr_drvd2);
122+
123+
py::classh<test_derived, py_test_derived>(m, "test_derived")
124+
.def(py::init_alias<>())
125+
.def("func", &test_derived::func);
126+
127+
m.def("test_override_cache", test_override_cache);
102128
}
103129

104130
} // namespace class_sh_inheritance

tests/test_class_sh_inheritance.py

+19
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,22 @@ def __init__(self):
6161
assert i1 == 110 + 21
6262
i2 = m.pass_cptr_base2(d)
6363
assert i2 == 120 + 22
64+
65+
66+
def test_python_override():
67+
def func():
68+
class Test(m.test_derived):
69+
def func(self):
70+
return 42
71+
72+
return Test()
73+
74+
def func2():
75+
class Test(m.test_derived):
76+
pass
77+
78+
return Test()
79+
80+
for _ in range(1500):
81+
assert m.test_override_cache(func()) == 42
82+
assert m.test_override_cache(func2()) == 0

tests/test_embed/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pybind11_enable_warnings(test_embed)
2525
target_link_libraries(test_embed PRIVATE pybind11::embed Catch2::Catch2 Threads::Threads)
2626

2727
if(NOT CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_CURRENT_BINARY_DIR)
28-
file(COPY test_interpreter.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}")
28+
file(COPY test_interpreter.py test_derived.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}")
2929
endif()
3030

3131
add_custom_target(

tests/test_embed/test_derived.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import derived_module
4+
5+
6+
def func():
7+
class Test(derived_module.test_derived):
8+
def func(self):
9+
return 42
10+
11+
return Test()
12+
13+
14+
def func2():
15+
class Test(derived_module.test_derived):
16+
pass
17+
18+
return Test()

tests/test_embed/test_interpreter.cpp

+49
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <pybind11/embed.h>
2+
#include <pybind11/smart_holder.h>
23

34
#ifdef _MSC_VER
45
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch
@@ -37,6 +38,24 @@ class PyWidget final : public Widget {
3738
std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); }
3839
};
3940

41+
class test_derived {
42+
43+
public:
44+
virtual int func() { return 0; }
45+
46+
test_derived() = default;
47+
virtual ~test_derived() = default;
48+
// Non-copyable
49+
test_derived &operator=(test_derived const &Right) = delete;
50+
test_derived(test_derived const &Copy) = delete;
51+
};
52+
53+
class py_test_derived : public test_derived {
54+
virtual int func() override { PYBIND11_OVERRIDE(int, test_derived, func); }
55+
};
56+
57+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(test_derived)
58+
4059
PYBIND11_EMBEDDED_MODULE(widget_module, m) {
4160
py::class_<Widget, PyWidget>(m, "Widget")
4261
.def(py::init<std::string>())
@@ -45,6 +64,12 @@ PYBIND11_EMBEDDED_MODULE(widget_module, m) {
4564
m.def("add", [](int i, int j) { return i + j; });
4665
}
4766

67+
PYBIND11_EMBEDDED_MODULE(derived_module, m) {
68+
py::classh<test_derived, py_test_derived>(m, "test_derived")
69+
.def(py::init_alias<>())
70+
.def("func", &test_derived::func);
71+
}
72+
4873
PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
4974
throw std::runtime_error("C++ Error");
5075
}
@@ -73,6 +98,30 @@ TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
7398
REQUIRE(cpp_widget.the_answer() == 42);
7499
}
75100

101+
TEST_CASE("Override cache") {
102+
auto module_ = py::module_::import("test_derived");
103+
REQUIRE(py::hasattr(module_, "func"));
104+
REQUIRE(py::hasattr(module_, "func2"));
105+
106+
auto locals = py::dict(**module_.attr("__dict__"));
107+
108+
int i = 0;
109+
for (; i < 1500; ++i) {
110+
std::shared_ptr<test_derived> p_obj;
111+
std::shared_ptr<test_derived> p_obj2;
112+
113+
p_obj = pybind11::cast<std::shared_ptr<test_derived>>(locals["func"]());
114+
115+
int ret = p_obj->func();
116+
117+
REQUIRE(ret == 42);
118+
119+
p_obj2 = pybind11::cast<std::shared_ptr<test_derived>>(locals["func2"]());
120+
121+
p_obj2->func();
122+
}
123+
}
124+
76125
TEST_CASE("Import error handling") {
77126
REQUIRE_NOTHROW(py::module_::import("widget_module"));
78127
REQUIRE_THROWS_WITH(py::module_::import("throw_exception"),

0 commit comments

Comments
 (0)