Skip to content

Commit 40c0b9f

Browse files
committed
feat: add #[pyo3(allow_threads)] to release the GIL in (async) functions
1 parent b11174e commit 40c0b9f

18 files changed

+432
-158
lines changed

newsfragments/3610.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions

pyo3-macros-backend/src/attributes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use syn::{
99
};
1010

1111
pub mod kw {
12+
syn::custom_keyword!(allow_threads);
1213
syn::custom_keyword!(annotation);
1314
syn::custom_keyword!(attribute);
1415
syn::custom_keyword!(cancel_handle);

pyo3-macros-backend/src/method.rs

Lines changed: 49 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
66

77
use crate::utils::Ctx;
88
use crate::{
9+
attributes,
910
attributes::{FromPyWithAttribute, TextSignatureAttribute, TextSignatureAttributeValue},
1011
deprecations::{Deprecation, Deprecations},
1112
params::{impl_arg_params, Holders},
@@ -379,6 +380,7 @@ pub struct FnSpec<'a> {
379380
pub asyncness: Option<syn::Token![async]>,
380381
pub unsafety: Option<syn::Token![unsafe]>,
381382
pub deprecations: Deprecations<'a>,
383+
pub allow_threads: Option<attributes::kw::allow_threads>,
382384
}
383385

384386
pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
@@ -416,6 +418,7 @@ impl<'a> FnSpec<'a> {
416418
text_signature,
417419
name,
418420
signature,
421+
allow_threads,
419422
..
420423
} = options;
421424

@@ -461,6 +464,7 @@ impl<'a> FnSpec<'a> {
461464
asyncness: sig.asyncness,
462465
unsafety: sig.unsafety,
463466
deprecations,
467+
allow_threads,
464468
})
465469
}
466470

@@ -603,6 +607,21 @@ impl<'a> FnSpec<'a> {
603607
bail_spanned!(name.span() => "`cancel_handle` may only be specified once");
604608
}
605609
}
610+
if let Some(FnArg::Py(py_arg)) = self
611+
.signature
612+
.arguments
613+
.iter()
614+
.find(|arg| matches!(arg, FnArg::Py(_)))
615+
{
616+
ensure_spanned!(
617+
self.asyncness.is_none(),
618+
py_arg.ty.span() => "GIL token cannot be passed to async function"
619+
);
620+
ensure_spanned!(
621+
self.allow_threads.is_none(),
622+
py_arg.ty.span() => "GIL cannot be held in function annotated with `allow_threads`"
623+
);
624+
}
606625

607626
if self.asyncness.is_some() {
608627
ensure_spanned!(
@@ -612,8 +631,21 @@ impl<'a> FnSpec<'a> {
612631
}
613632

614633
let rust_call = |args: Vec<TokenStream>, holders: &mut Holders| {
615-
let mut self_arg = || self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);
616-
634+
let allow_threads = self.allow_threads.is_some();
635+
let mut self_arg = || {
636+
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);
637+
if self_arg.is_empty() {
638+
self_arg
639+
} else {
640+
let self_checker = holders.push_gil_refs_checker(self_arg.span());
641+
quote! {
642+
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
643+
}
644+
}
645+
};
646+
let arg_names = (0..args.len())
647+
.map(|i| format_ident!("arg_{}", i))
648+
.collect::<Vec<_>>();
617649
let call = if self.asyncness.is_some() {
618650
let throw_callback = if cancel_handle.is_some() {
619651
quote! { Some(__throw_callback) }
@@ -625,9 +657,6 @@ impl<'a> FnSpec<'a> {
625657
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)),
626658
None => quote!(None),
627659
};
628-
let arg_names = (0..args.len())
629-
.map(|i| format_ident!("arg_{}", i))
630-
.collect::<Vec<_>>();
631660
let future = match self.tp {
632661
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
633662
quote! {{
@@ -645,18 +674,7 @@ impl<'a> FnSpec<'a> {
645674
}
646675
_ => {
647676
let self_arg = self_arg();
648-
if self_arg.is_empty() {
649-
quote! { function(#(#args),*) }
650-
} else {
651-
let self_checker = holders.push_gil_refs_checker(self_arg.span());
652-
quote! {
653-
function(
654-
// NB #self_arg includes a comma, so none inserted here
655-
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
656-
#(#args),*
657-
)
658-
}
659-
}
677+
quote!(function(#self_arg #(#args),*))
660678
}
661679
};
662680
let mut call = quote! {{
@@ -665,6 +683,7 @@ impl<'a> FnSpec<'a> {
665683
#pyo3_path::intern!(py, stringify!(#python_name)),
666684
#qualname_prefix,
667685
#throw_callback,
686+
#allow_threads,
668687
async move { #pyo3_path::impl_::wrap::OkWrap::wrap(future.await) },
669688
)
670689
}};
@@ -676,20 +695,21 @@ impl<'a> FnSpec<'a> {
676695
}};
677696
}
678697
call
679-
} else {
698+
} else if allow_threads {
680699
let self_arg = self_arg();
681-
if self_arg.is_empty() {
682-
quote! { function(#(#args),*) }
700+
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
701+
(quote!(), quote!())
683702
} else {
684-
let self_checker = holders.push_gil_refs_checker(self_arg.span());
685-
quote! {
686-
function(
687-
// NB #self_arg includes a comma, so none inserted here
688-
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
689-
#(#args),*
690-
)
691-
}
692-
}
703+
(quote!(__self,), quote! { let (__self,) = (#self_arg); })
704+
};
705+
quote! {{
706+
#self_arg_decl
707+
#(let #arg_names = #args;)*
708+
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
709+
}}
710+
} else {
711+
let self_arg = self_arg();
712+
quote!(function(#self_arg #(#args),*))
693713
};
694714
quotes::map_result_into_ptr(quotes::ok_wrap(call, ctx), ctx)
695715
};

pyo3-macros-backend/src/pyclass.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,7 @@ fn complex_enum_struct_variant_new<'a>(
11741174
asyncness: None,
11751175
unsafety: None,
11761176
deprecations: Deprecations::new(ctx),
1177+
allow_threads: None,
11771178
};
11781179

11791180
crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
@@ -1199,6 +1200,7 @@ fn complex_enum_variant_field_getter<'a>(
11991200
asyncness: None,
12001201
unsafety: None,
12011202
deprecations: Deprecations::new(ctx),
1203+
allow_threads: None,
12021204
};
12031205

12041206
let property_type = crate::pymethod::PropertyType::Function {

pyo3-macros-backend/src/pyfunction.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
9191
pub signature: Option<SignatureAttribute>,
9292
pub text_signature: Option<TextSignatureAttribute>,
9393
pub krate: Option<CrateAttribute>,
94+
pub allow_threads: Option<attributes::kw::allow_threads>,
9495
}
9596

9697
impl Parse for PyFunctionOptions {
@@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {
99100

100101
while !input.is_empty() {
101102
let lookahead = input.lookahead1();
102-
if lookahead.peek(attributes::kw::name)
103+
if lookahead.peek(attributes::kw::allow_threads)
104+
|| lookahead.peek(attributes::kw::name)
103105
|| lookahead.peek(attributes::kw::pass_module)
104106
|| lookahead.peek(attributes::kw::signature)
105107
|| lookahead.peek(attributes::kw::text_signature)
@@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
121123
}
122124

123125
pub enum PyFunctionOption {
126+
AllowThreads(attributes::kw::allow_threads),
124127
Name(NameAttribute),
125128
PassModule(attributes::kw::pass_module),
126129
Signature(SignatureAttribute),
@@ -131,7 +134,9 @@ pub enum PyFunctionOption {
131134
impl Parse for PyFunctionOption {
132135
fn parse(input: ParseStream<'_>) -> Result<Self> {
133136
let lookahead = input.lookahead1();
134-
if lookahead.peek(attributes::kw::name) {
137+
if lookahead.peek(attributes::kw::allow_threads) {
138+
input.parse().map(PyFunctionOption::AllowThreads)
139+
} else if lookahead.peek(attributes::kw::name) {
135140
input.parse().map(PyFunctionOption::Name)
136141
} else if lookahead.peek(attributes::kw::pass_module) {
137142
input.parse().map(PyFunctionOption::PassModule)
@@ -171,6 +176,7 @@ impl PyFunctionOptions {
171176
}
172177
for attr in attrs {
173178
match attr {
179+
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
174180
PyFunctionOption::Name(name) => set_option!(name),
175181
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
176182
PyFunctionOption::Signature(signature) => set_option!(signature),
@@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
198204
) -> syn::Result<TokenStream> {
199205
check_generic(&func.sig)?;
200206
let PyFunctionOptions {
207+
allow_threads,
201208
pass_module,
202209
name,
203210
signature,
@@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
247254
python_name,
248255
signature,
249256
text_signature,
257+
allow_threads,
250258
asyncness: func.sig.asyncness,
251259
unsafety: func.sig.unsafety,
252260
deprecations: Deprecations::new(ctx),

pyo3-macros/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
121121
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
122122
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
123123
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
124+
/// | `#[pyo3(allow_threads)]` | Release the GIL in the function body, or each time the returned future is polled for `async fn` |
124125
///
125126
/// For more on exposing functions see the [function section of the guide][1].
126127
///

src/coroutine.rs

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,54 @@ use crate::{
2121
pub(crate) mod cancel;
2222
mod waker;
2323

24+
use crate::marker::Ungil;
2425
pub use cancel::CancelHandle;
2526

2627
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
2728

29+
trait CoroutineFuture: Send {
30+
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>>;
31+
}
32+
33+
impl<F, T, E> CoroutineFuture for F
34+
where
35+
F: Future<Output = Result<T, E>> + Send,
36+
T: IntoPy<PyObject> + Send,
37+
E: Into<PyErr> + Send,
38+
{
39+
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
40+
self.poll(&mut Context::from_waker(waker))
41+
.map_ok(|obj| obj.into_py(py))
42+
.map_err(Into::into)
43+
}
44+
}
45+
46+
struct AllowThreads<F> {
47+
future: F,
48+
}
49+
50+
impl<F, T, E> CoroutineFuture for AllowThreads<F>
51+
where
52+
F: Future<Output = Result<T, E>> + Send + Ungil,
53+
T: IntoPy<PyObject> + Send + Ungil,
54+
E: Into<PyErr> + Send + Ungil,
55+
{
56+
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
57+
// SAFETY: future field is pinned when self is
58+
let future = unsafe { self.map_unchecked_mut(|a| &mut a.future) };
59+
py.allow_threads(|| future.poll(&mut Context::from_waker(waker)))
60+
.map_ok(|obj| obj.into_py(py))
61+
.map_err(Into::into)
62+
}
63+
}
64+
2865
/// Python coroutine wrapping a [`Future`].
2966
#[pyclass(crate = "crate")]
3067
pub struct Coroutine {
3168
name: Option<Py<PyString>>,
3269
qualname_prefix: Option<&'static str>,
3370
throw_callback: Option<ThrowCallback>,
34-
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
71+
future: Option<Pin<Box<dyn CoroutineFuture>>>,
3572
waker: Option<Arc<AsyncioWaker>>,
3673
}
3774

@@ -46,23 +83,23 @@ impl Coroutine {
4683
name: Option<Py<PyString>>,
4784
qualname_prefix: Option<&'static str>,
4885
throw_callback: Option<ThrowCallback>,
86+
allow_threads: bool,
4987
future: F,
5088
) -> Self
5189
where
52-
F: Future<Output = Result<T, E>> + Send + 'static,
53-
T: IntoPy<PyObject>,
54-
E: Into<PyErr>,
90+
F: Future<Output = Result<T, E>> + Send + Ungil + 'static,
91+
T: IntoPy<PyObject> + Send + Ungil,
92+
E: Into<PyErr> + Send + Ungil,
5593
{
56-
let wrap = async move {
57-
let obj = future.await.map_err(Into::into)?;
58-
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
59-
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
60-
};
6194
Self {
6295
name,
6396
qualname_prefix,
6497
throw_callback,
65-
future: Some(Box::pin(wrap)),
98+
future: Some(if allow_threads {
99+
Box::pin(AllowThreads { future })
100+
} else {
101+
Box::pin(future)
102+
}),
66103
waker: None,
67104
}
68105
}
@@ -88,10 +125,10 @@ impl Coroutine {
88125
} else {
89126
self.waker = Some(Arc::new(AsyncioWaker::new()));
90127
}
91-
let waker = Waker::from(self.waker.clone().unwrap());
92-
// poll the Rust future and forward its results if ready
128+
// poll the future and forward its results if ready
93129
// polling is UnwindSafe because the future is dropped in case of panic
94-
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
130+
let waker = Waker::from(self.waker.clone().unwrap());
131+
let poll = || future_rs.as_mut().poll(py, &waker);
95132
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
96133
Ok(Poll::Ready(res)) => {
97134
self.close();

0 commit comments

Comments
 (0)