@@ -13,6 +13,7 @@ use async_lock::RwLock;
13
13
use async_trait:: async_trait;
14
14
use std:: sync:: Arc ;
15
15
use std:: time:: Duration ;
16
+ use typespec_client_core:: date:: OffsetDateTime ;
16
17
use typespec_client_core:: http:: { Context , Request } ;
17
18
18
19
/// Authentication policy for a bearer token.
@@ -23,9 +24,6 @@ pub struct BearerTokenCredentialPolicy {
23
24
access_token : Arc < RwLock < Option < AccessToken > > > ,
24
25
}
25
26
26
- /// Default timeout in seconds before refreshing a new token.
27
- const DEFAULT_REFRESH_TIME : Duration = Duration :: from_secs ( 120 ) ;
28
-
29
27
impl BearerTokenCredentialPolicy {
30
28
pub fn new < A , B > ( credential : Arc < dyn TokenCredential > , scopes : A ) -> Self
31
29
where
@@ -63,16 +61,44 @@ impl Policy for BearerTokenCredentialPolicy {
63
61
) -> PolicyResult {
64
62
let access_token = self . access_token . read ( ) . await ;
65
63
66
- if let Some ( token) = & ( * access_token) {
67
- if token. is_expired ( Some ( DEFAULT_REFRESH_TIME ) ) {
64
+ match access_token. as_ref ( ) {
65
+ None => {
66
+ // cache is empty. Upgrade the lock and acquire a token, provided another thread hasn't already done so
67
+ drop ( access_token) ;
68
+ let mut access_token = self . access_token . write ( ) . await ;
69
+ if access_token. is_none ( ) {
70
+ * access_token = Some ( self . credential . get_token ( & self . scopes ( ) ) . await ?) ;
71
+ }
72
+ }
73
+ Some ( token) if should_refresh ( & token. expires_on ) => {
74
+ // token is expired or within its refresh window. Upgrade the lock and
75
+ // acquire a new token, provided another thread hasn't already done so
76
+ let expires_on = token. expires_on ;
68
77
drop ( access_token) ;
69
78
let mut access_token = self . access_token . write ( ) . await ;
70
- * access_token = Some ( self . credential . get_token ( & self . scopes ( ) ) . await ?) ;
79
+ // access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic
80
+ if access_token. is_none ( ) || access_token. as_ref ( ) . unwrap ( ) . expires_on == expires_on
81
+ {
82
+ match self . credential . get_token ( & self . scopes ( ) ) . await {
83
+ Ok ( new_token) => {
84
+ * access_token = Some ( new_token) ;
85
+ }
86
+ Err ( e)
87
+ if access_token. is_none ( )
88
+ || expires_on <= OffsetDateTime :: now_utc ( ) =>
89
+ {
90
+ // propagate this error because we can't proceed without a new token
91
+ return Err ( e) ;
92
+ }
93
+ Err ( _) => {
94
+ // ignore this error because the cached token is still valid
95
+ }
96
+ }
97
+ }
98
+ }
99
+ Some ( _) => {
100
+ // do nothing; cached token is valid and not within its refresh window
71
101
}
72
- } else {
73
- drop ( access_token) ;
74
- let mut access_token = self . access_token . write ( ) . await ;
75
- * access_token = Some ( self . credential . get_token ( & self . scopes ( ) ) . await ?) ;
76
102
}
77
103
78
104
let access_token = self . access_token ( ) . await . ok_or_else ( || {
@@ -86,3 +112,161 @@ impl Policy for BearerTokenCredentialPolicy {
86
112
next[ 0 ] . send ( ctx, request, & next[ 1 ..] ) . await
87
113
}
88
114
}
115
+
116
+ fn should_refresh ( expires_on : & OffsetDateTime ) -> bool {
117
+ * expires_on <= OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( 300 )
118
+ }
119
+
120
+ #[ cfg( test) ]
121
+ mod tests {
122
+ use super :: * ;
123
+ use crate :: {
124
+ credentials:: { Secret , TokenCredential } ,
125
+ http:: {
126
+ headers:: { Headers , AUTHORIZATION } ,
127
+ policies:: Policy ,
128
+ Request , Response , StatusCode ,
129
+ } ,
130
+ Bytes , Result ,
131
+ } ;
132
+ use async_trait:: async_trait;
133
+ use azure_core_test:: http:: MockHttpClient ;
134
+ use futures:: FutureExt ;
135
+ use std:: sync:: {
136
+ atomic:: { AtomicUsize , Ordering } ,
137
+ Arc ,
138
+ } ;
139
+ use std:: time:: Duration ;
140
+ use time:: OffsetDateTime ;
141
+ use typespec_client_core:: http:: { policies:: TransportPolicy , Method , TransportOptions } ;
142
+
143
+ #[ derive( Debug , Clone ) ]
144
+ struct MockCredential {
145
+ calls : Arc < AtomicUsize > ,
146
+ tokens : Arc < [ AccessToken ] > ,
147
+ }
148
+
149
+ impl MockCredential {
150
+ fn new ( tokens : & [ AccessToken ] ) -> Self {
151
+ Self {
152
+ calls : Arc :: new ( AtomicUsize :: new ( 0 ) ) ,
153
+ tokens : tokens. into ( ) ,
154
+ }
155
+ }
156
+
157
+ fn get_token_calls ( & self ) -> usize {
158
+ self . calls . load ( Ordering :: SeqCst )
159
+ }
160
+ }
161
+
162
+ // ensure the number of get_token() calls matches the number of tokens
163
+ // in a test case i.e., that the policy called get_token() as expected
164
+ impl Drop for MockCredential {
165
+ fn drop ( & mut self ) {
166
+ if !self . tokens . is_empty ( ) {
167
+ assert_eq ! ( self . tokens. len( ) , self . calls. load( Ordering :: SeqCst ) ) ;
168
+ }
169
+ }
170
+ }
171
+
172
+ #[ cfg_attr( target_arch = "wasm32" , async_trait( ?Send ) ) ]
173
+ #[ cfg_attr( not( target_arch = "wasm32" ) , async_trait) ]
174
+ impl TokenCredential for MockCredential {
175
+ async fn get_token ( & self , _scopes : & [ & str ] ) -> Result < AccessToken > {
176
+ let i = self . calls . fetch_add ( 1 , Ordering :: SeqCst ) ;
177
+ self . tokens
178
+ . get ( i)
179
+ . ok_or_else ( || Error :: message ( ErrorKind :: Credential , "no more mock tokens" ) )
180
+ . cloned ( )
181
+ }
182
+ }
183
+
184
+ #[ tokio:: test]
185
+ async fn authn_error ( ) {
186
+ // this mock's get_token() will return an error because it has no tokens
187
+ let credential = MockCredential :: new ( & [ ] ) ;
188
+ let policy = BearerTokenCredentialPolicy :: new ( Arc :: new ( credential) , [ "scope" ] ) ;
189
+ let client = MockHttpClient :: new ( |_| panic ! ( "expected an error from get_token" ) ) ;
190
+ let transport = Arc :: new ( TransportPolicy :: new ( TransportOptions :: new ( Arc :: new (
191
+ client,
192
+ ) ) ) ) ;
193
+ let mut req = Request :: new ( "https://localhost" . parse ( ) . unwrap ( ) , Method :: Get ) ;
194
+
195
+ let err = policy
196
+ . send ( & Context :: default ( ) , & mut req, & [ transport. clone ( ) ] )
197
+ . await
198
+ . expect_err ( "request should fail" ) ;
199
+
200
+ assert_eq ! ( ErrorKind :: Credential , * err. kind( ) ) ;
201
+ }
202
+
203
+ async fn run_test ( tokens : & [ AccessToken ] ) {
204
+ let credential = Arc :: new ( MockCredential :: new ( tokens) ) ;
205
+ let policy = BearerTokenCredentialPolicy :: new ( credential. clone ( ) , [ "scope" ] ) ;
206
+ let client = Arc :: new ( MockHttpClient :: new ( move |actual| {
207
+ let credential = credential. clone ( ) ;
208
+ async move {
209
+ let authz = actual. headers ( ) . get_str ( & AUTHORIZATION ) ?;
210
+ // e.g. if this is the first request, we expect 1 get_token call and tokens[0] in the header
211
+ let i = credential. get_token_calls ( ) . saturating_sub ( 1 ) ;
212
+ let expected = & credential. tokens [ i] ;
213
+
214
+ assert_eq ! ( format!( "Bearer {}" , expected. token. secret( ) ) , authz) ;
215
+
216
+ Ok ( Response :: from_bytes (
217
+ StatusCode :: Ok ,
218
+ Headers :: new ( ) ,
219
+ Bytes :: new ( ) ,
220
+ ) )
221
+ }
222
+ . boxed ( )
223
+ } ) ) ;
224
+ let transport = Arc :: new ( TransportPolicy :: new ( TransportOptions :: new ( client) ) ) ;
225
+
226
+ let mut handles = vec ! [ ] ;
227
+ for _ in 0 ..4 {
228
+ let policy = policy. clone ( ) ;
229
+ let transport = transport. clone ( ) ;
230
+ let handle = tokio:: spawn ( async move {
231
+ let ctx = Context :: default ( ) ;
232
+ let mut req = Request :: new ( "https://localhost" . parse ( ) . unwrap ( ) , Method :: Get ) ;
233
+ policy
234
+ . send ( & ctx, & mut req, & [ transport. clone ( ) ] )
235
+ . await
236
+ . expect ( "successful request" ) ;
237
+ } ) ;
238
+ handles. push ( handle) ;
239
+ }
240
+
241
+ for handle in handles {
242
+ tokio:: time:: timeout ( Duration :: from_secs ( 2 ) , handle)
243
+ . await
244
+ . expect ( "task timed out after 2 seconds" )
245
+ . expect ( "completed task" ) ;
246
+ }
247
+ }
248
+
249
+ #[ tokio:: test]
250
+ async fn caches_token ( ) {
251
+ run_test ( & [ AccessToken {
252
+ token : Secret :: new ( "fake" . to_string ( ) ) ,
253
+ expires_on : OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( 3600 ) ,
254
+ } ] )
255
+ . await ;
256
+ }
257
+
258
+ #[ tokio:: test]
259
+ async fn refreshes_token ( ) {
260
+ run_test ( & [
261
+ AccessToken {
262
+ token : Secret :: new ( "1" . to_string ( ) ) ,
263
+ expires_on : OffsetDateTime :: now_utc ( ) - Duration :: from_secs ( 1 ) ,
264
+ } ,
265
+ AccessToken {
266
+ token : Secret :: new ( "2" . to_string ( ) ) ,
267
+ expires_on : OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( 3600 ) ,
268
+ } ,
269
+ ] )
270
+ . await ;
271
+ }
272
+ }
0 commit comments