Skip to content

Commit 0aa7a2a

Browse files
committed
refactor(test): execute all #[rustup_macros::unit_test]s within a tokio context
1 parent 3ea4355 commit 0aa7a2a

File tree

2 files changed

+41
-82
lines changed

2 files changed

+41
-82
lines changed

rustup-macros/src/lib.rs

+41-66
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,16 @@ pub fn integration_test(
3939
.into()
4040
}
4141

42-
/// Custom wrapper macro around `#[test]` and `#[tokio::test]` for unit tests.
42+
/// Custom wrapper macro around `#[tokio::test]` for unit tests.
4343
///
4444
/// Calls `rustup::test::before_test()` before the test body, and
4545
/// `rustup::test::after_test()` after, even in the event of an unwinding panic.
46-
/// For async functions calls the async variants of these functions.
46+
///
47+
/// This wrapper makes the underlying test function async even if it's sync in nature.
48+
/// This ensures that a [`tokio`] runtime is always present during tests,
49+
/// making it easier to setup [`tracing`] subscribers
50+
/// (e.g. [`opentelemetry_otlp::OtlpTracePipeline`] always requires a [`tokio`] runtime to be
51+
/// installed).
4752
#[proc_macro_attribute]
4853
pub fn unit_test(
4954
args: proc_macro::TokenStream,
@@ -77,74 +82,44 @@ pub fn unit_test(
7782
.into()
7883
}
7984

80-
// False positive from clippy :/
81-
#[allow(clippy::redundant_clone)]
8285
fn test_inner(mod_path: String, mut input: ItemFn) -> syn::Result<TokenStream> {
83-
if input.sig.asyncness.is_some() {
84-
let before_ident = format!("{}::before_test_async", mod_path);
85-
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
86-
let after_ident = format!("{}::after_test_async", mod_path);
87-
let after_ident = syn::parse_str::<Expr>(&after_ident)?;
88-
89-
let inner = input.block;
90-
let name = input.sig.ident.clone();
91-
let new_block: Block = parse_quote! {
92-
{
93-
#before_ident().await;
94-
// Define a function with same name we can instrument inside the
95-
// tracing enablement logic.
96-
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
97-
async fn #name() { #inner }
98-
// Thunk through a new thread to permit catching the panic
99-
// without grabbing the entire state machine defined by the
100-
// outer test function.
101-
let result = ::std::panic::catch_unwind(||{
102-
let handle = tokio::runtime::Handle::current().clone();
103-
::std::thread::spawn(move || handle.block_on(#name())).join().unwrap()
104-
});
105-
#after_ident().await;
106-
match result {
107-
Ok(result) => result,
108-
Err(err) => ::std::panic::resume_unwind(err)
109-
}
110-
}
111-
};
86+
// Make the test function async even if it's sync.
87+
input.sig.asyncness.get_or_insert_with(Default::default);
11288

113-
input.block = Box::new(new_block);
89+
let before_ident = format!("{}::before_test_async", mod_path);
90+
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
91+
let after_ident = format!("{}::after_test_async", mod_path);
92+
let after_ident = syn::parse_str::<Expr>(&after_ident)?;
11493

115-
Ok(quote! {
94+
let inner = input.block;
95+
let name = input.sig.ident.clone();
96+
let new_block: Block = parse_quote! {
97+
{
98+
#before_ident().await;
99+
// Define a function with same name we can instrument inside the
100+
// tracing enablement logic.
116101
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
117-
#[::tokio::test(flavor = "multi_thread", worker_threads = 1)]
118-
#input
119-
})
120-
} else {
121-
let before_ident = format!("{}::before_test", mod_path);
122-
let before_ident = syn::parse_str::<Expr>(&before_ident)?;
123-
let after_ident = format!("{}::after_test", mod_path);
124-
let after_ident = syn::parse_str::<Expr>(&after_ident)?;
125-
126-
let inner = input.block;
127-
let name = input.sig.ident.clone();
128-
let new_block: Block = parse_quote! {
129-
{
130-
#before_ident();
131-
// Define a function with same name we can instrument inside the
132-
// tracing enablement logic.
133-
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
134-
fn #name() { #inner }
135-
let result = ::std::panic::catch_unwind(#name);
136-
#after_ident();
137-
match result {
138-
Ok(result) => result,
139-
Err(err) => ::std::panic::resume_unwind(err)
140-
}
102+
async fn #name() { #inner }
103+
// Thunk through a new thread to permit catching the panic
104+
// without grabbing the entire state machine defined by the
105+
// outer test function.
106+
let result = ::std::panic::catch_unwind(||{
107+
let handle = tokio::runtime::Handle::current().clone();
108+
::std::thread::spawn(move || handle.block_on(#name())).join().unwrap()
109+
});
110+
#after_ident().await;
111+
match result {
112+
Ok(result) => result,
113+
Err(err) => ::std::panic::resume_unwind(err)
141114
}
142-
};
115+
}
116+
};
143117

144-
input.block = Box::new(new_block);
145-
Ok(quote! {
146-
#[::std::prelude::v1::test]
147-
#input
148-
})
149-
}
118+
input.block = Box::new(new_block);
119+
120+
Ok(quote! {
121+
#[cfg_attr(feature = "otel", tracing::instrument(skip_all))]
122+
#[::tokio::test(flavor = "multi_thread", worker_threads = 1)]
123+
#input
124+
})
150125
}

src/test.rs

-16
Original file line numberDiff line numberDiff line change
@@ -277,29 +277,13 @@ static TRACER: Lazy<opentelemetry_sdk::trace::Tracer> = Lazy::new(|| {
277277
tracer
278278
});
279279

280-
pub fn before_test() {
281-
#[cfg(feature = "otel")]
282-
{
283-
Lazy::force(&TRACER);
284-
}
285-
}
286-
287280
pub async fn before_test_async() {
288281
#[cfg(feature = "otel")]
289282
{
290283
Lazy::force(&TRACER);
291284
}
292285
}
293286

294-
pub fn after_test() {
295-
#[cfg(feature = "otel")]
296-
{
297-
let handle = TRACE_RUNTIME.handle();
298-
let _guard = handle.enter();
299-
TRACER.provider().map(|p| p.force_flush());
300-
}
301-
}
302-
303287
pub async fn after_test_async() {
304288
#[cfg(feature = "otel")]
305289
{

0 commit comments

Comments
 (0)