diff --git a/src/lib.rs b/src/lib.rs index b5b5bd0..8142c60 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,7 +37,10 @@ extern crate proc_macro; use proc_macro::TokenStream; -use quote::quote; +use proc_macro2::{Ident, Span, TokenStream as TokenStream2}; +use quote::{quote, ToTokens}; +use syn::parse::{self, Parse, ParseStream}; +use syn::Token; /// Add context to errors from a function. /// @@ -55,32 +58,67 @@ use quote::quote; /// })().map_err(|err| err.context("context")) /// } /// ``` +/// +/// Sometimes you will receive borrowck errors, especially when returning references. These can +/// often be fixed by setting the `move` option of the attribute macro. For example: +/// +/// ``` +/// use fn_error_context::context; +/// +/// #[context(move, "context")] +/// fn returns_reference(val: &mut u32) -> anyhow::Result<&mut u32> { +/// Ok(&mut *val) +/// } +/// ``` #[proc_macro_attribute] pub fn context(args: TokenStream, input: TokenStream) -> TokenStream { - let args: proc_macro2::TokenStream = args.into(); + let Args(move_token, format_args) = syn::parse_macro_input!(args); let mut input = syn::parse_macro_input!(input as syn::ItemFn); let body = &input.block; let return_ty = &input.sig.output; - if input.sig.asyncness.is_some() { - match return_ty { + let err = Ident::new("err", Span::mixed_site()); + let new_body = if input.sig.asyncness.is_some() { + let return_ty = match return_ty { syn::ReturnType::Default => { - return syn::Error::new_spanned(return_ty, "function should return Result") + return syn::Error::new_spanned(input, "function should return Result") .to_compile_error() - .into() - } - syn::ReturnType::Type(_, return_ty) => { - input.block.stmts = syn::parse_quote!( - let result: #return_ty = async { #body }.await; - result.map_err(|err| err.context(format!(#args)).into()) - ); + .into(); } + syn::ReturnType::Type(_, return_ty) => return_ty, + }; + let result = Ident::new("result", Span::mixed_site()); + quote! { + let #result: #return_ty = async #move_token { #body }.await; + #result.map_err(|#err| #err.context(format!(#format_args)).into()) } } else { - input.block.stmts = syn::parse_quote!( - (|| #return_ty #body)().map_err(|err| err.context(format!(#args)).into()) - ); - } + let force_fn_once = Ident::new("force_fn_once", Span::mixed_site()); + quote! { + // Moving a non-`Copy` value into the closure tells borrowck to always treat the closure + // as a `FnOnce`, preventing some borrowing errors. + let #force_fn_once = ::core::iter::empty::<()>(); + (#move_token || #return_ty { + ::core::mem::drop(#force_fn_once); + #body + })().map_err(|#err| #err.context(format!(#format_args)).into()) + } + }; + input.block.stmts = vec![syn::Stmt::Expr(syn::Expr::Verbatim(new_body))]; - quote!(#input).into() + input.into_token_stream().into() +} + +struct Args(Option, TokenStream2); +impl Parse for Args { + fn parse(input: ParseStream<'_>) -> parse::Result { + let move_token = if input.peek(Token![move]) { + let token = input.parse()?; + input.parse::()?; + Some(token) + } else { + None + }; + Ok(Self(move_token, input.parse()?)) + } } diff --git a/tests/async_borrowing.rs b/tests/async_borrowing.rs new file mode 100644 index 0000000..7148c66 --- /dev/null +++ b/tests/async_borrowing.rs @@ -0,0 +1,8 @@ +use fn_error_context::context; + +#[context("context")] +async fn borrows(x: &mut u32) -> anyhow::Result<&mut u32> { + Ok(x) +} + +fn main() {} diff --git a/tests/async_move.rs b/tests/async_move.rs new file mode 100644 index 0000000..3e6eaf7 --- /dev/null +++ b/tests/async_move.rs @@ -0,0 +1,8 @@ +use fn_error_context::context; + +#[context(move, "context")] +async fn borrows(val: &mut u32) -> anyhow::Result<&u32> { + Ok(&*val) +} + +fn main() {} diff --git a/tests/async_no_move.rs b/tests/async_no_move.rs new file mode 100644 index 0000000..7d2c40c --- /dev/null +++ b/tests/async_no_move.rs @@ -0,0 +1,9 @@ +use fn_error_context::context; + +#[context("{}", context.as_ref())] +async fn no_move(context: impl AsRef) -> anyhow::Result<()> { + context.as_ref(); + Ok(()) +} + +fn main() {} diff --git a/tests/async_without_return.stderr b/tests/async_without_return.stderr index fc35392..3f85983 100644 --- a/tests/async_without_return.stderr +++ b/tests/async_without_return.stderr @@ -1,7 +1,6 @@ error: function should return Result - --> $DIR/async_without_return.rs:3:1 + --> $DIR/async_without_return.rs:4:1 | -3 | #[context("context")] - | ^^^^^^^^^^^^^^^^^^^^^ - | - = note: this error originates in an attribute macro (in Nightly builds, run with -Z macro-backtrace for more info) +4 | / async fn async_something() { +5 | | } + | |_^ diff --git a/tests/move.rs b/tests/move.rs new file mode 100644 index 0000000..549983b --- /dev/null +++ b/tests/move.rs @@ -0,0 +1,8 @@ +use fn_error_context::context; + +#[context(move, "context")] +fn foo(x: &mut u32) -> anyhow::Result<&u32> { + Ok(&*x) +} + +fn main() {} diff --git a/tests/no_move.rs b/tests/no_move.rs new file mode 100644 index 0000000..b2d3f50 --- /dev/null +++ b/tests/no_move.rs @@ -0,0 +1,9 @@ +use fn_error_context::context; + +#[context("{}", context.as_ref())] +fn no_move(context: impl AsRef) -> anyhow::Result<()> { + context.as_ref(); + Ok(()) +} + +fn main() {} diff --git a/tests/tests.rs b/tests/tests.rs index 48c0fac..c9b0275 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -17,4 +17,9 @@ fn tests() { tests.pass("tests/fmt_named_arg.rs"); tests.compile_fail("tests/async_without_return.rs"); tests.compile_fail("tests/preserve_lint.rs"); + tests.pass("tests/async_borrowing.rs"); + tests.pass("tests/no_move.rs"); + tests.pass("tests/async_no_move.rs"); + tests.pass("tests/move.rs"); + tests.pass("tests/async_move.rs"); }