Skip to content

Commit d66be8b

Browse files
committed
feat: add coroutine::CancelHandle
1 parent 81ad2e8 commit d66be8b

14 files changed

+286
-21
lines changed

guide/src/async-await.md

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,27 @@ where
6969

7070
## Cancellation
7171

72-
*To be implemented*
72+
Cancellation on the Python side can be caught using [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) type, by annotating a function parameter with `#[pyo3(cancel_handle)].
73+
74+
```rust
75+
# #![allow(dead_code)]
76+
use futures::FutureExt;
77+
use pyo3::prelude::*;
78+
use pyo3::coroutine::CancelHandle;
79+
80+
#[pyfunction]
81+
async fn cancellable(#[pyo3(cancel_handle)]mut cancel: CancelHandle) {
82+
futures::select! {
83+
/* _ = ... => println!("done"), */
84+
_ = cancel.cancelled().fuse() => println!("cancelled"),
85+
}
86+
}
87+
```
7388

7489
## The `Coroutine` type
7590

76-
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine). Each `coroutine.send` call is translated to `Future::poll` call, while `coroutine.throw` call reraise the exception *(this behavior will be configurable with cancellation support)*.
91+
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine).
92+
93+
Each `coroutine.send` call is translated to `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;
7794

7895
*The type does not yet have a public constructor until the design is finalized.*

newsfragments/3599.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `coroutine::CancelHandle` to catch coroutine cancellation

pyo3-macros-backend/src/attributes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use syn::{
1111
pub mod kw {
1212
syn::custom_keyword!(annotation);
1313
syn::custom_keyword!(attribute);
14+
syn::custom_keyword!(cancel_handle);
1415
syn::custom_keyword!(dict);
1516
syn::custom_keyword!(extends);
1617
syn::custom_keyword!(freelist);

pyo3-macros-backend/src/method.rs

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pub struct FnArg<'a> {
2424
pub attrs: PyFunctionArgPyO3Attributes,
2525
pub is_varargs: bool,
2626
pub is_kwargs: bool,
27+
pub is_cancel_handle: bool,
2728
}
2829

2930
impl<'a> FnArg<'a> {
@@ -44,6 +45,7 @@ impl<'a> FnArg<'a> {
4445
other => return Err(handle_argument_error(other)),
4546
};
4647

48+
let is_cancel_handle = arg_attrs.cancel_handle.is_some();
4749
Ok(FnArg {
4850
name: ident,
4951
ty: &cap.ty,
@@ -53,6 +55,7 @@ impl<'a> FnArg<'a> {
5355
attrs: arg_attrs,
5456
is_varargs: false,
5557
is_kwargs: false,
58+
is_cancel_handle,
5659
})
5760
}
5861
}
@@ -455,9 +458,27 @@ impl<'a> FnSpec<'a> {
455458
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
456459
let func_name = &self.name;
457460

461+
let mut cancel_handle_iter = self
462+
.signature
463+
.arguments
464+
.iter()
465+
.filter(|arg| arg.is_cancel_handle);
466+
let cancel_handle = cancel_handle_iter.next();
467+
if let Some(arg) = cancel_handle {
468+
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
469+
if let Some(arg2) = cancel_handle_iter.next() {
470+
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
471+
}
472+
}
473+
458474
let rust_call = |args: Vec<TokenStream>| {
459475
let mut call = quote! { function(#self_arg #(#args),*) };
460476
if self.asyncness.is_some() {
477+
let throw_callback = if cancel_handle.is_some() {
478+
quote! { Some(__throw_callback) }
479+
} else {
480+
quote! { None }
481+
};
461482
let python_name = &self.python_name;
462483
let qualname_prefix = match cls {
463484
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
@@ -468,9 +489,17 @@ impl<'a> FnSpec<'a> {
468489
_pyo3::impl_::coroutine::new_coroutine(
469490
_pyo3::intern!(py, stringify!(#python_name)),
470491
#qualname_prefix,
471-
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }
492+
#throw_callback,
493+
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
472494
)
473495
}};
496+
if cancel_handle.is_some() {
497+
call = quote! {{
498+
let __cancel_handle = _pyo3::coroutine::CancelHandle::new();
499+
let __throw_callback = __cancel_handle.throw_callback();
500+
#call
501+
}};
502+
}
474503
}
475504
quotes::map_result_into_ptr(quotes::ok_wrap(call))
476505
};
@@ -483,12 +512,21 @@ impl<'a> FnSpec<'a> {
483512

484513
Ok(match self.convention {
485514
CallingConvention::Noargs => {
486-
let call = if !self.signature.arguments.is_empty() {
487-
// Only `py` arg can be here
488-
rust_call(vec![quote!(py)])
489-
} else {
490-
rust_call(vec![])
491-
};
515+
let args = self
516+
.signature
517+
.arguments
518+
.iter()
519+
.map(|arg| {
520+
if arg.py {
521+
quote!(py)
522+
} else if arg.is_cancel_handle {
523+
quote!(__cancel_handle)
524+
} else {
525+
unreachable!()
526+
}
527+
})
528+
.collect();
529+
let call = rust_call(args);
492530

493531
quote! {
494532
unsafe fn #ident<'py>(

pyo3-macros-backend/src/params.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ fn impl_arg_param(
155155
return Ok(quote! { py });
156156
}
157157

158+
if arg.is_cancel_handle {
159+
return Ok(quote! { __cancel_handle });
160+
}
161+
158162
let name = arg.name;
159163
let name_str = name.to_string();
160164

pyo3-macros-backend/src/pyfunction.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,20 @@ pub use self::signature::{FunctionSignature, SignatureAttribute};
2323
#[derive(Clone, Debug)]
2424
pub struct PyFunctionArgPyO3Attributes {
2525
pub from_py_with: Option<FromPyWithAttribute>,
26+
pub cancel_handle: Option<attributes::kw::cancel_handle>,
2627
}
2728

2829
enum PyFunctionArgPyO3Attribute {
2930
FromPyWith(FromPyWithAttribute),
31+
CancelHandle(attributes::kw::cancel_handle),
3032
}
3133

3234
impl Parse for PyFunctionArgPyO3Attribute {
3335
fn parse(input: ParseStream<'_>) -> Result<Self> {
3436
let lookahead = input.lookahead1();
35-
if lookahead.peek(attributes::kw::from_py_with) {
37+
if lookahead.peek(attributes::kw::cancel_handle) {
38+
input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
39+
} else if lookahead.peek(attributes::kw::from_py_with) {
3640
input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
3741
} else {
3842
Err(lookahead.error())
@@ -43,7 +47,10 @@ impl Parse for PyFunctionArgPyO3Attribute {
4347
impl PyFunctionArgPyO3Attributes {
4448
/// Parses #[pyo3(from_python_with = "func")]
4549
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
46-
let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None };
50+
let mut attributes = PyFunctionArgPyO3Attributes {
51+
from_py_with: None,
52+
cancel_handle: None,
53+
};
4754
take_attributes(attrs, |attr| {
4855
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
4956
for attr in pyo3_attrs {
@@ -55,7 +62,18 @@ impl PyFunctionArgPyO3Attributes {
5562
);
5663
attributes.from_py_with = Some(from_py_with);
5764
}
65+
PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
66+
ensure_spanned!(
67+
attributes.cancel_handle.is_none(),
68+
cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
69+
);
70+
attributes.cancel_handle = Some(cancel_handle);
71+
}
5872
}
73+
ensure_spanned!(
74+
attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
75+
attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
76+
);
5977
}
6078
Ok(true)
6179
} else {

pyo3-macros-backend/src/pyfunction/signature.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
361361
// Otherwise try next argument.
362362
continue;
363363
}
364+
if fn_arg.is_cancel_handle {
365+
// If the user incorrectly tried to include cancel: CoroutineCancel in the
366+
// signature, give a useful error as a hint.
367+
ensure_spanned!(
368+
name != fn_arg.name,
369+
name.span() => "`cancel_handle` argument must not be part of the signature"
370+
);
371+
// Otherwise try next argument.
372+
continue;
373+
}
364374

365375
ensure_spanned!(
366376
name == fn_arg.name,
@@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
411421
}
412422

413423
// Ensure no non-py arguments remain
414-
if let Some(arg) = args_iter.find(|arg| !arg.py) {
424+
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.is_cancel_handle) {
415425
bail_spanned!(
416426
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
417427
);
@@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
429439
let mut python_signature = PythonSignature::default();
430440
for arg in &arguments {
431441
// Python<'_> arguments don't show in Python signature
432-
if arg.py {
442+
if arg.py || arg.is_cancel_handle {
433443
continue;
434444
}
435445

src/coroutine.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@ use crate::{
2121
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
2222
};
2323

24+
pub(crate) mod cancel;
2425
mod waker;
2526

27+
use crate::coroutine::cancel::ThrowCallback;
28+
pub use cancel::CancelHandle;
29+
2630
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
2731

2832
type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
@@ -32,6 +36,7 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
3236
pub struct Coroutine {
3337
name: Option<Py<PyString>>,
3438
qualname_prefix: Option<&'static str>,
39+
throw_callback: Option<ThrowCallback>,
3540
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
3641
waker: Option<Arc<AsyncioWaker>>,
3742
}
@@ -46,6 +51,7 @@ impl Coroutine {
4651
pub(crate) fn new<F, T, E>(
4752
name: Option<Py<PyString>>,
4853
qualname_prefix: Option<&'static str>,
54+
throw_callback: Option<ThrowCallback>,
4955
future: F,
5056
) -> Self
5157
where
@@ -61,6 +67,7 @@ impl Coroutine {
6167
Self {
6268
name,
6369
qualname_prefix,
70+
throw_callback,
6471
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
6572
waker: None,
6673
}
@@ -77,9 +84,13 @@ impl Coroutine {
7784
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
7885
};
7986
// reraise thrown exception it
80-
if let Some(exc) = throw {
81-
self.close();
82-
return Err(PyErr::from_value(exc.as_ref(py)));
87+
match (throw, &self.throw_callback) {
88+
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
89+
(Some(exc), None) => {
90+
self.close();
91+
return Err(PyErr::from_value(exc.as_ref(py)));
92+
}
93+
_ => {}
8394
}
8495
// create a new waker, or try to reset it in place
8596
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {

src/coroutine/cancel.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
use crate::{ffi, Py, PyAny, PyObject};
2+
use futures_util::future::poll_fn;
3+
use futures_util::task::AtomicWaker;
4+
use std::ptr;
5+
use std::ptr::NonNull;
6+
use std::sync::atomic::{AtomicPtr, Ordering};
7+
use std::sync::Arc;
8+
use std::task::{Context, Poll};
9+
10+
#[derive(Debug, Default)]
11+
struct Inner {
12+
exception: AtomicPtr<ffi::PyObject>,
13+
waker: AtomicWaker,
14+
}
15+
16+
/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine).
17+
///
18+
/// Only the last exception thrown can be retrieved.
19+
#[derive(Debug, Default)]
20+
pub struct CancelHandle(Arc<Inner>);
21+
22+
impl CancelHandle {
23+
/// Create a new `CoroutineCancel`.
24+
pub fn new() -> Self {
25+
Default::default()
26+
}
27+
28+
/// Returns whether the associated coroutine has been cancelled.
29+
pub fn is_cancelled(&self) -> bool {
30+
!self.0.exception.load(Ordering::Relaxed).is_null()
31+
}
32+
33+
/// Poll to retrieve the exception thrown in the associated coroutine.
34+
pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
35+
// SAFETY: only valid owned pointer are set in `ThrowCallback::throw`
36+
let take = || unsafe {
37+
// pointer cannot be null because it is checked the line before,
38+
// and the swap is protected by `&mut self`
39+
Py::from_non_null(
40+
NonNull::new(self.0.exception.swap(ptr::null_mut(), Ordering::Relaxed)).unwrap(),
41+
)
42+
};
43+
if self.is_cancelled() {
44+
return Poll::Ready(take());
45+
}
46+
self.0.waker.register(cx.waker());
47+
if self.is_cancelled() {
48+
return Poll::Ready(take());
49+
}
50+
Poll::Pending
51+
}
52+
53+
/// Retrieve the exception thrown in the associated coroutine.
54+
pub async fn cancelled(&mut self) -> PyObject {
55+
poll_fn(|cx| self.poll_cancelled(cx)).await
56+
}
57+
58+
#[doc(hidden)]
59+
pub fn throw_callback(&self) -> ThrowCallback {
60+
ThrowCallback(self.0.clone())
61+
}
62+
}
63+
64+
#[doc(hidden)]
65+
pub struct ThrowCallback(Arc<Inner>);
66+
67+
impl ThrowCallback {
68+
pub(super) fn throw(&self, exc: &PyAny) {
69+
let ptr = self.0.exception.swap(exc.into_ptr(), Ordering::Relaxed);
70+
// SAFETY: non-null pointers set in `self.0.exceptions` are valid owned pointers
71+
drop(unsafe { PyObject::from_owned_ptr_or_opt(exc.py(), ptr) });
72+
self.0.waker.wake();
73+
}
74+
}

src/impl_/coroutine.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
use std::future::Future;
22

3+
use crate::coroutine::cancel::ThrowCallback;
34
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
45

56
pub fn new_coroutine<F, T, E>(
67
name: &PyString,
78
qualname_prefix: Option<&'static str>,
9+
throw_callback: Option<ThrowCallback>,
810
future: F,
911
) -> Coroutine
1012
where
1113
F: Future<Output = Result<T, E>> + Send + 'static,
1214
T: IntoPy<PyObject>,
1315
E: Into<PyErr>,
1416
{
15-
Coroutine::new(Some(name.into()), qualname_prefix, future)
17+
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
1618
}

src/instance.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,7 @@ impl<T> Py<T> {
10381038
/// # Safety
10391039
/// `ptr` must point to a Python object of type T.
10401040
#[inline]
1041-
unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
1041+
pub(crate) unsafe fn from_non_null(ptr: NonNull<ffi::PyObject>) -> Self {
10421042
Self(ptr, PhantomData)
10431043
}
10441044

0 commit comments

Comments
 (0)