Skip to content

Commit 9955bcc

Browse files
committed
Attach python lifetime to shared_ptr passed to C++
- Reference cycles are possible as a result, but shared_ptr is already susceptible to this in C++
1 parent 721834b commit 9955bcc

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

include/pybind11/cast.h

+31
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,9 @@ struct copyable_holder_caster : public type_caster_base<type> {
15631563
throw cast_error("Unable to load a custom holder type from a default-holder instance");
15641564
}
15651565

1566+
// holders that are not std::shared_ptr
1567+
template <typename T = holder_type,
1568+
detail::enable_if_t<!is_shared_ptr<T>::value, int> = 0>
15661569
bool load_value(value_and_holder &&v_h) {
15671570
if (v_h.holder_constructed()) {
15681571
value = v_h.value_ptr();
@@ -1578,6 +1581,34 @@ struct copyable_holder_caster : public type_caster_base<type> {
15781581
}
15791582
}
15801583

1584+
// holders that are std::shared_ptr
1585+
template <typename T = holder_type,
1586+
detail::enable_if_t<is_shared_ptr<T>::value, int> = 0>
1587+
bool load_value(value_and_holder &&v_h) {
1588+
if (v_h.holder_constructed()) {
1589+
value = v_h.value_ptr();
1590+
1591+
// The shared_ptr is always given to C++ code, so we construct a new shared_ptr
1592+
// that is given a custom deleter. The custom deleter increments the python
1593+
// reference count to bind the python instance lifetime with the lifetime
1594+
// of the shared_ptr.
1595+
//
1596+
// This enables things like passing the last python reference of a subclass to a
1597+
// C++ function without the python reference dying.
1598+
//
1599+
// Reference cycles will cause a leak, but this is a limitation of shared_ptr
1600+
holder = holder_type((type*)value, [unused = reinterpret_borrow<object>((PyObject*)v_h.inst)](type*){});
1601+
return true;
1602+
} else {
1603+
throw cast_error("Unable to cast from non-held to held instance (T& to Holder<T>) "
1604+
#if defined(NDEBUG)
1605+
"(compile in debug mode for type information)");
1606+
#else
1607+
"of type '" + type_id<holder_type>() + "''");
1608+
#endif
1609+
}
1610+
}
1611+
15811612
template <typename T = holder_type, detail::enable_if_t<!std::is_constructible<T, const T &, type*>::value, int> = 0>
15821613
bool try_implicit_casts(handle, bool) { return false; }
15831614

tests/test_smart_ptr.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -397,4 +397,33 @@ TEST_SUBMODULE(smart_ptr, m) {
397397
list.append(py::cast(e));
398398
return list;
399399
});
400+
401+
// For testing whether a python subclass of a C++ object dies when the
402+
// last python reference is lost
403+
struct SpBase {
404+
// returns true if the base virtual function is called
405+
virtual bool is_base_used() { return true; }
406+
virtual ~SpBase() = default;
407+
};
408+
409+
struct PySpBase : SpBase {
410+
bool is_base_used() override { PYBIND11_OVERRIDE(bool, SpBase, is_base_used); }
411+
};
412+
413+
struct SpBaseTester {
414+
std::shared_ptr<SpBase> get_object() { return m_obj; }
415+
void set_object(std::shared_ptr<SpBase> obj) { m_obj = obj; }
416+
bool is_base_used() { return m_obj->is_base_used(); }
417+
std::shared_ptr<SpBase> m_obj;
418+
};
419+
420+
py::class_<SpBase, std::shared_ptr<SpBase>, PySpBase>(m, "SpBase")
421+
.def(py::init<>())
422+
.def("is_base_used", &SpBase::is_base_used);
423+
424+
py::class_<SpBaseTester>(m, "SpBaseTester")
425+
.def(py::init<>())
426+
.def("get_object", &SpBaseTester::get_object)
427+
.def("set_object", &SpBaseTester::set_object)
428+
.def("is_base_used", &SpBaseTester::is_base_used);
400429
}

tests/test_smart_ptr.py

+44
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,47 @@ def test_shared_ptr_gc():
316316
pytest.gc_collect()
317317
for i, v in enumerate(el.get()):
318318
assert i == v.value()
319+
320+
321+
def test_shared_ptr_cpp_arg():
322+
import weakref
323+
324+
class PyChild(m.SpBase):
325+
def is_base_used(self):
326+
return False
327+
328+
tester = m.SpBaseTester()
329+
330+
obj = PyChild()
331+
objref = weakref.ref(obj)
332+
333+
tester.set_object(obj)
334+
del obj
335+
pytest.gc_collect()
336+
337+
# python reference is still around since C++ has it now
338+
assert objref() is not None
339+
assert tester.is_base_used() == False
340+
assert tester.get_object() is objref()
341+
342+
343+
def test_shared_ptr_arg_identity():
344+
import weakref
345+
346+
tester = m.SpBaseTester()
347+
348+
obj = m.SpBase()
349+
objref = weakref.ref(obj)
350+
351+
tester.set_object(obj)
352+
del obj
353+
pytest.gc_collect()
354+
355+
# python reference is still around since C++ has it
356+
assert objref() is not None
357+
assert tester.get_object() is objref()
358+
359+
# python reference disappears once the C++ object releases it
360+
tester.set_object(None)
361+
pytest.gc_collect()
362+
assert objref() is None

0 commit comments

Comments
 (0)