Skip to content

Commit 5b8ef2c

Browse files
Fix how thread states are created (pybind#1276)
Having pybind11 keep its own internal thread state can lead to an inconsistent situation where the Python interpreter has a thread state but pybind does not, and then when gil_scoped_acquire is called, pybind creates a new thread state instead of using the one created by the Python interpreter. This change gets rid of pybind's internal thread state and always uses the one created by the Python interpreter.
1 parent 2d0507d commit 5b8ef2c

File tree

6 files changed

+55
-26
lines changed

6 files changed

+55
-26
lines changed

include/pybind11/detail/internals.h

-3
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ struct internals {
7979
PyTypeObject *default_metaclass;
8080
PyObject *instance_base;
8181
#if defined(WITH_THREAD)
82-
decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
8382
PyInterpreterState *istate = nullptr;
8483
#endif
8584
};
@@ -166,8 +165,6 @@ PYBIND11_NOINLINE inline internals &get_internals() {
166165
#if defined(WITH_THREAD)
167166
PyEval_InitThreads();
168167
PyThreadState *tstate = PyThreadState_Get();
169-
internals_ptr->tstate = PyThread_create_key();
170-
PyThread_set_key_value(internals_ptr->tstate, tstate);
171168
internals_ptr->istate = tstate->interp;
172169
#endif
173170
builtins[id] = capsule(internals_pp);

include/pybind11/pybind11.h

+3-23
Original file line numberDiff line numberDiff line change
@@ -1755,7 +1755,7 @@ class gil_scoped_acquire {
17551755
public:
17561756
PYBIND11_NOINLINE gil_scoped_acquire() {
17571757
auto const &internals = detail::get_internals();
1758-
tstate = (PyThreadState *) PyThread_get_key_value(internals.tstate);
1758+
tstate = PyGILState_GetThisThreadState();
17591759

17601760
if (!tstate) {
17611761
tstate = PyThreadState_New(internals.istate);
@@ -1764,10 +1764,6 @@ class gil_scoped_acquire {
17641764
pybind11_fail("scoped_acquire: could not create thread state!");
17651765
#endif
17661766
tstate->gilstate_counter = 0;
1767-
#if PY_MAJOR_VERSION < 3
1768-
PyThread_delete_key_value(internals.tstate);
1769-
#endif
1770-
PyThread_set_key_value(internals.tstate, tstate);
17711767
} else {
17721768
release = detail::get_thread_state_unchecked() != tstate;
17731769
}
@@ -1806,7 +1802,6 @@ class gil_scoped_acquire {
18061802
#endif
18071803
PyThreadState_Clear(tstate);
18081804
PyThreadState_DeleteCurrent();
1809-
PyThread_delete_key_value(detail::get_internals().tstate);
18101805
release = false;
18111806
}
18121807
}
@@ -1825,30 +1820,15 @@ class gil_scoped_release {
18251820
public:
18261821
explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) {
18271822
// `get_internals()` must be called here unconditionally in order to initialize
1828-
// `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an
1823+
// `internals.istate` for subsequent `gil_scoped_acquire` calls. Otherwise, an
18291824
// initialization race could occur as multiple threads try `gil_scoped_acquire`.
1830-
const auto &internals = detail::get_internals();
1825+
detail::get_internals();
18311826
tstate = PyEval_SaveThread();
1832-
if (disassoc) {
1833-
auto key = internals.tstate;
1834-
#if PY_MAJOR_VERSION < 3
1835-
PyThread_delete_key_value(key);
1836-
#else
1837-
PyThread_set_key_value(key, nullptr);
1838-
#endif
1839-
}
18401827
}
18411828
~gil_scoped_release() {
18421829
if (!tstate)
18431830
return;
18441831
PyEval_RestoreThread(tstate);
1845-
if (disassoc) {
1846-
auto key = detail::get_internals().tstate;
1847-
#if PY_MAJOR_VERSION < 3
1848-
PyThread_delete_key_value(key);
1849-
#endif
1850-
PyThread_set_key_value(key, tstate);
1851-
}
18521832
}
18531833
private:
18541834
PyThreadState *tstate;

tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ set(PYBIND11_TEST_FILES
5757
test_smart_ptr.cpp
5858
test_stl.cpp
5959
test_stl_binders.cpp
60+
test_threads.cpp
6061
test_virtual_functions.cpp
6162
)
6263

tests/conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ def pytest_namespace():
199199
from pybind11_tests.eigen import have_eigen
200200
except ImportError:
201201
have_eigen = False
202+
try:
203+
import threading
204+
except ImportError:
205+
threading = None
202206
pypy = platform.python_implementation() == "PyPy"
203207

204208
skipif = pytest.mark.skipif
@@ -210,6 +214,7 @@ def pytest_namespace():
210214
reason="eigen and/or numpy are not installed"),
211215
'requires_eigen_and_scipy': skipif(not have_eigen or not scipy,
212216
reason="eigen and/or scipy are not installed"),
217+
'requires_threading': skipif(not threading, reason="no threading"),
213218
'unsupported_on_pypy': skipif(pypy, reason="unsupported on PyPy"),
214219
'unsupported_on_py2': skipif(sys.version_info.major < 3,
215220
reason="unsupported on Python 2.x"),

tests/test_threads.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "pybind11_tests.h"
2+
3+
#if defined(WITH_THREAD)
4+
5+
#include <thread>
6+
7+
static bool check_threadstate() {
8+
return PyGILState_GetThisThreadState() == PyThreadState_Get();
9+
}
10+
11+
TEST_SUBMODULE(threads, m) {
12+
m.def("check_pythread", []() -> bool {
13+
py::gil_scoped_acquire acquire1;
14+
return check_threadstate();
15+
}, py::call_guard<py::gil_scoped_release>())
16+
.def("check_cthread", []() -> bool {
17+
bool result = false;
18+
std::thread thread([&result]() {
19+
py::gil_scoped_acquire acquire;
20+
result = check_threadstate();
21+
});
22+
thread.join();
23+
return result;
24+
}, py::call_guard<py::gil_scoped_release>());
25+
}
26+
27+
#endif

tests/test_threads.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
pytestmark = pytest.requires_threading
4+
5+
with pytest.suppress(ImportError):
6+
import threading
7+
from pybind11_tests import threads as t
8+
9+
10+
def test_threads():
11+
def pythread_routine():
12+
threading.current_thread()._return = t.check_pythread()
13+
14+
thread = threading.Thread(target=pythread_routine)
15+
thread.start()
16+
thread.join()
17+
assert thread._return
18+
19+
assert t.check_cthread()

0 commit comments

Comments
 (0)