@@ -39,11 +39,16 @@ pub fn integration_test(
39
39
. into ( )
40
40
}
41
41
42
- /// Custom wrapper macro around `#[test]` and `#[ tokio::test]` for unit tests.
42
+ /// Custom wrapper macro around `#[tokio::test]` for unit tests.
43
43
///
44
44
/// Calls `rustup::test::before_test()` before the test body, and
45
45
/// `rustup::test::after_test()` after, even in the event of an unwinding panic.
46
- /// For async functions calls the async variants of these functions.
46
+ ///
47
+ /// This wrapper makes the underlying test function async even if it's sync in nature.
48
+ /// This ensures that a [`tokio`] runtime is always present during tests,
49
+ /// making it easier to setup [`tracing`] subscribers
50
+ /// (e.g. [`opentelemetry_otlp::OtlpTracePipeline`] always requires a [`tokio`] runtime to be
51
+ /// installed).
47
52
#[ proc_macro_attribute]
48
53
pub fn unit_test (
49
54
args : proc_macro:: TokenStream ,
@@ -77,74 +82,44 @@ pub fn unit_test(
77
82
. into ( )
78
83
}
79
84
80
- // False positive from clippy :/
81
- #[ allow( clippy:: redundant_clone) ]
82
85
fn test_inner ( mod_path : String , mut input : ItemFn ) -> syn:: Result < TokenStream > {
83
- if input. sig . asyncness . is_some ( ) {
84
- let before_ident = format ! ( "{}::before_test_async" , mod_path) ;
85
- let before_ident = syn:: parse_str :: < Expr > ( & before_ident) ?;
86
- let after_ident = format ! ( "{}::after_test_async" , mod_path) ;
87
- let after_ident = syn:: parse_str :: < Expr > ( & after_ident) ?;
88
-
89
- let inner = input. block ;
90
- let name = input. sig . ident . clone ( ) ;
91
- let new_block: Block = parse_quote ! {
92
- {
93
- #before_ident( ) . await ;
94
- // Define a function with same name we can instrument inside the
95
- // tracing enablement logic.
96
- #[ cfg_attr( feature = "otel" , tracing:: instrument( skip_all) ) ]
97
- async fn #name( ) { #inner }
98
- // Thunk through a new thread to permit catching the panic
99
- // without grabbing the entire state machine defined by the
100
- // outer test function.
101
- let result = :: std:: panic:: catch_unwind( ||{
102
- let handle = tokio:: runtime:: Handle :: current( ) . clone( ) ;
103
- :: std:: thread:: spawn( move || handle. block_on( #name( ) ) ) . join( ) . unwrap( )
104
- } ) ;
105
- #after_ident( ) . await ;
106
- match result {
107
- Ok ( result) => result,
108
- Err ( err) => :: std:: panic:: resume_unwind( err)
109
- }
110
- }
111
- } ;
86
+ // Make the test function async even if it's sync.
87
+ input. sig . asyncness . get_or_insert_with ( Default :: default) ;
112
88
113
- input. block = Box :: new ( new_block) ;
89
+ let before_ident = format ! ( "{}::before_test_async" , mod_path) ;
90
+ let before_ident = syn:: parse_str :: < Expr > ( & before_ident) ?;
91
+ let after_ident = format ! ( "{}::after_test_async" , mod_path) ;
92
+ let after_ident = syn:: parse_str :: < Expr > ( & after_ident) ?;
114
93
115
- Ok ( quote ! {
94
+ let inner = input. block ;
95
+ let name = input. sig . ident . clone ( ) ;
96
+ let new_block: Block = parse_quote ! {
97
+ {
98
+ #before_ident( ) . await ;
99
+ // Define a function with same name we can instrument inside the
100
+ // tracing enablement logic.
116
101
#[ cfg_attr( feature = "otel" , tracing:: instrument( skip_all) ) ]
117
- #[ :: tokio:: test( flavor = "multi_thread" , worker_threads = 1 ) ]
118
- #input
119
- } )
120
- } else {
121
- let before_ident = format ! ( "{}::before_test" , mod_path) ;
122
- let before_ident = syn:: parse_str :: < Expr > ( & before_ident) ?;
123
- let after_ident = format ! ( "{}::after_test" , mod_path) ;
124
- let after_ident = syn:: parse_str :: < Expr > ( & after_ident) ?;
125
-
126
- let inner = input. block ;
127
- let name = input. sig . ident . clone ( ) ;
128
- let new_block: Block = parse_quote ! {
129
- {
130
- #before_ident( ) ;
131
- // Define a function with same name we can instrument inside the
132
- // tracing enablement logic.
133
- #[ cfg_attr( feature = "otel" , tracing:: instrument( skip_all) ) ]
134
- fn #name( ) { #inner }
135
- let result = :: std:: panic:: catch_unwind( #name) ;
136
- #after_ident( ) ;
137
- match result {
138
- Ok ( result) => result,
139
- Err ( err) => :: std:: panic:: resume_unwind( err)
140
- }
102
+ async fn #name( ) { #inner }
103
+ // Thunk through a new thread to permit catching the panic
104
+ // without grabbing the entire state machine defined by the
105
+ // outer test function.
106
+ let result = :: std:: panic:: catch_unwind( ||{
107
+ let handle = tokio:: runtime:: Handle :: current( ) . clone( ) ;
108
+ :: std:: thread:: spawn( move || handle. block_on( #name( ) ) ) . join( ) . unwrap( )
109
+ } ) ;
110
+ #after_ident( ) . await ;
111
+ match result {
112
+ Ok ( result) => result,
113
+ Err ( err) => :: std:: panic:: resume_unwind( err)
141
114
}
142
- } ;
115
+ }
116
+ } ;
143
117
144
- input. block = Box :: new ( new_block) ;
145
- Ok ( quote ! {
146
- #[ :: std:: prelude:: v1:: test]
147
- #input
148
- } )
149
- }
118
+ input. block = Box :: new ( new_block) ;
119
+
120
+ Ok ( quote ! {
121
+ #[ cfg_attr( feature = "otel" , tracing:: instrument( skip_all) ) ]
122
+ #[ :: tokio:: test( flavor = "multi_thread" , worker_threads = 1 ) ]
123
+ #input
124
+ } )
150
125
}
0 commit comments