Skip to content

Commit f993841

Browse files
committed
Fix unit tests
1 parent ea677c3 commit f993841

File tree

3 files changed

+71
-45
lines changed

3 files changed

+71
-45
lines changed

include/pybind11/detail/internals.h

+55-35
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@ inline PyThreadState *get_thread_state_unchecked() {
270270

271271
/// We use this count to figure out if there are multiple sub-interpreters currently present.
272272
/// Might be read/written from multiple interpreters in multiple threads at the same time with no
273-
/// syncronization. This must never decrease while any interpreter may be running in any thread!
274-
inline std::atomic<int64_t> &get_interpreter_count() {
275-
static std::atomic<int64_t> counter(0);
273+
/// synchronization. This must never decrease while any interpreter may be running in any thread!
274+
inline std::atomic<int> &get_interpreter_count() {
275+
static std::atomic<int> counter(0);
276276
return counter;
277277
}
278278

@@ -444,25 +444,9 @@ inline uint64_t round_up_to_next_pow2(uint64_t x) {
444444
}
445445

446446
template <typename InternalsType>
447-
inline InternalsType **find_or_create_internals_pp(char const *state_dict_key) {
448-
#if defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
449-
gil_scoped_acquire gil;
450-
#else
451-
// Ensure that the GIL is held since we will need to make Python calls.
452-
// Cannot use py::gil_scoped_acquire here since that constructor calls get_internals.
453-
struct gil_scoped_acquire_local {
454-
gil_scoped_acquire_local() : state(PyGILState_Ensure()) {}
455-
gil_scoped_acquire_local(const gil_scoped_acquire_local &) = delete;
456-
gil_scoped_acquire_local &operator=(const gil_scoped_acquire_local &) = delete;
457-
~gil_scoped_acquire_local() { PyGILState_Release(state); }
458-
const PyGILState_STATE state;
459-
} gil;
460-
#endif
461-
462-
error_scope err_scope;
463-
447+
inline InternalsType **find_internals_pp(char const *state_dict_key) {
464448
dict state_dict = get_python_state_dict();
465-
object internals_obj
449+
auto internals_obj
466450
= reinterpret_steal<object>(dict_getitemstringref(state_dict.ptr(), state_dict_key));
467451
if (internals_obj) {
468452
void *raw_ptr = PyCapsule_GetPointer(internals_obj.ptr(), /*name=*/nullptr);
@@ -472,21 +456,38 @@ inline InternalsType **find_or_create_internals_pp(char const *state_dict_key) {
472456
return reinterpret_cast<InternalsType **>(raw_ptr);
473457
}
474458
}
475-
476-
auto pp = new InternalsType *(nullptr);
477-
state_dict[state_dict_key] = capsule(reinterpret_cast<void *>(pp));
478-
return pp;
459+
return nullptr;
479460
}
480461

462+
#if defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
463+
using simple_gil_scoped_acquire = gil_scoped_acquire;
464+
#else
465+
// Cannot use py::gil_scoped_acquire inside get_internals since that calls get_internals.
466+
struct simple_gil_scoped_acquire {
467+
simple_gil_scoped_acquire() : state(PyGILState_Ensure()) {}
468+
simple_gil_scoped_acquire(const simple_gil_scoped_acquire &) = delete;
469+
simple_gil_scoped_acquire &operator=(const simple_gil_scoped_acquire &) = delete;
470+
simple_gil_scoped_acquire(simple_gil_scoped_acquire &&) = delete;
471+
simple_gil_scoped_acquire &operator=(simple_gil_scoped_acquire &&) = delete;
472+
~simple_gil_scoped_acquire() { PyGILState_Release(state); }
473+
const PyGILState_STATE state;
474+
};
475+
#endif
476+
481477
/// Return a reference to the current `internals` data
482478
PYBIND11_NOINLINE internals &get_internals() {
483479
auto **&internals_pp = get_internals_pp();
484480
if (internals_pp && *internals_pp) {
485481
return **internals_pp;
486482
}
487483

488-
internals_pp = find_or_create_internals_pp<internals>(PYBIND11_INTERNALS_ID);
489-
if (*internals_pp) {
484+
simple_gil_scoped_acquire gil;
485+
486+
error_scope err_scope;
487+
488+
internals_pp = find_internals_pp<internals>(PYBIND11_INTERNALS_ID);
489+
490+
if (internals_pp && *internals_pp) {
490491
// We loaded the internals through `state_dict`, which means that our `error_already_set`
491492
// and `builtin_exception` may be different local classes than the ones set up in the
492493
// initial exception translator, below, so add another for our local exception classes.
@@ -503,10 +504,16 @@ PYBIND11_NOINLINE internals &get_internals() {
503504
}
504505
#endif
505506
} else {
507+
if (!internals_pp) {
508+
internals_pp = new internals *(nullptr);
509+
dict state = get_python_state_dict();
510+
state[PYBIND11_INTERNALS_ID] = capsule(reinterpret_cast<void *>(internals_pp));
511+
}
512+
506513
auto *&internals_ptr = *internals_pp;
507514
internals_ptr = new internals();
508515

509-
PyThreadState *tstate = get_thread_state_unchecked();
516+
PyThreadState *tstate = PyThreadState_Get();
510517
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
511518
if (!PYBIND11_TLS_KEY_CREATE(internals_ptr->tstate)) {
512519
pybind11_fail("get_internals: could not successfully initialize the tstate TSS key!");
@@ -572,24 +579,37 @@ inline local_internals **&get_local_internals_pp() {
572579
return s_internals_pp;
573580
}
574581

582+
inline char const *get_local_internals_id() {
583+
// we create a unique value at first run which is based on a pointer to a (non-thread_local)
584+
// static value in this function, then multiple loaded modules using this code will still each
585+
// have a unique key.
586+
static const std::string this_module_idstr
587+
= PYBIND11_MODULE_LOCAL_ID
588+
+ std::to_string(reinterpret_cast<uintptr_t>(&this_module_idstr));
589+
return this_module_idstr.c_str();
590+
}
591+
575592
/// Works like `get_internals`, but for things which are locally registered.
576593
inline local_internals &get_local_internals() {
577594
auto **&local_internals_pp = get_local_internals_pp();
578595
if (local_internals_pp && *local_internals_pp) {
579596
return **local_internals_pp;
580597
}
581598

582-
// we create a unique value at first run which is based on a pointer to a (non-thread_local)
583-
// static value in this function, then multiple loaded modules using this code will still each
584-
// have a unique key.
585-
static const std::string this_module_idstr
586-
= PYBIND11_MODULE_LOCAL_ID
587-
+ std::to_string(reinterpret_cast<uintptr_t>(&this_module_idstr));
599+
simple_gil_scoped_acquire gil;
588600

589-
local_internals_pp = find_or_create_internals_pp<local_internals>(this_module_idstr.c_str());
601+
error_scope err_scope;
602+
603+
local_internals_pp = find_internals_pp<local_internals>(get_local_internals_id());
604+
if (!local_internals_pp) {
605+
local_internals_pp = new local_internals *(nullptr);
606+
dict state = get_python_state_dict();
607+
state[get_local_internals_id()] = capsule(reinterpret_cast<void *>(local_internals_pp));
608+
}
590609
if (!*local_internals_pp) {
591610
*local_internals_pp = new local_internals();
592611
}
612+
593613
return **local_internals_pp;
594614
}
595615

include/pybind11/embed.h

+8-6
Original file line numberDiff line numberDiff line change
@@ -236,21 +236,23 @@ inline void finalize_interpreter() {
236236
// Get the internals pointer (without creating it if it doesn't exist). It's possible for the
237237
// internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
238238
// during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
239-
auto **&internals_pp = detail::get_internals_pp();
239+
auto **&internals_ptr_ptr = detail::get_internals_pp();
240+
internals_ptr_ptr = detail::find_internals_pp<detail::internals>(PYBIND11_INTERNALS_ID);
240241
auto **&local_internals_pp = detail::get_local_internals_pp();
242+
local_internals_pp = detail::find_internals_pp<detail::local_internals>(detail::get_local_internals_id());
241243

242244
Py_Finalize();
243245

244-
if (internals_pp) {
245-
delete *internals_pp;
246-
*internals_pp = nullptr;
246+
if (internals_ptr_ptr) {
247+
delete *internals_ptr_ptr;
248+
*internals_ptr_ptr = nullptr;
247249
}
248250

249251
// Local internals contains data managed by the current interpreter, so we must clear them to
250252
// avoid undefined behaviors when initializing another interpreter
251-
if (local_internals_pp) {
253+
if (local_internals_pp && *local_internals_pp) {
252254
delete *local_internals_pp;
253-
local_internals_pp = nullptr;
255+
*local_internals_pp = nullptr;
254256
}
255257

256258
detail::get_interpreter_count() = 0;

tests/test_embed/test_interpreter.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,8 @@ TEST_CASE("Add program dir to path using PyConfig") {
256256
#endif
257257

258258
bool has_state_dict_internals_obj() {
259-
return bool(
260-
py::detail::get_internals_obj_from_state_dict(py::detail::get_python_state_dict()));
259+
py::dict state = py::detail::get_python_state_dict();
260+
return state.contains(PYBIND11_INTERNALS_ID);
261261
}
262262

263263
bool has_pybind11_internals_static() {
@@ -308,6 +308,7 @@ TEST_CASE("Restart the interpreter") {
308308
REQUIRE_FALSE(ran);
309309
py::finalize_interpreter();
310310
REQUIRE(ran);
311+
REQUIRE_FALSE(has_pybind11_internals_static());
311312
py::initialize_interpreter();
312313
REQUIRE_FALSE(has_state_dict_internals_obj());
313314
REQUIRE_FALSE(has_pybind11_internals_static());
@@ -340,15 +341,16 @@ TEST_CASE("Subinterpreter") {
340341
auto *main_tstate = PyThreadState_Get();
341342
auto *sub_tstate = Py_NewInterpreter();
342343

344+
py::detail::get_interpreter_count()++;
345+
343346
// Subinterpreters get their own copy of builtins.
344347
REQUIRE_FALSE(has_state_dict_internals_obj());
345348

346349
#if defined(PYBIND11_SUBINTERPRETER_SUPPORT) && PY_VERSION_HEX >= 0x030C0000
347350
// internals hasn't been populated yet, but will be different for the subinterpreter
348351
REQUIRE_FALSE(has_pybind11_internals_static());
349352

350-
py::list sys_path = py::module_::import("sys").attr("path");
351-
sys_path.append(py::str("."));
353+
py::list(py::module_::import("sys").attr("path")).append(py::str("."));
352354

353355
auto ext_int = py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>();
354356
py::detail::get_internals();
@@ -374,11 +376,13 @@ TEST_CASE("Subinterpreter") {
374376

375377
// Restore main interpreter.
376378
Py_EndInterpreter(sub_tstate);
379+
py::detail::get_interpreter_count() = 1;
377380
PyThreadState_Swap(main_tstate);
378381

379382
REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag"));
380383
REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag"));
381384
REQUIRE(has_state_dict_internals_obj());
385+
382386
}
383387

384388
TEST_CASE("Execution frame") {

0 commit comments

Comments
 (0)