Skip to content

Commit 71fe449

Browse files
authored
Remove AccessToken::is_expired() (#2611)
1 parent a9a5544 commit 71fe449

File tree

4 files changed

+208
-21
lines changed

4 files changed

+208
-21
lines changed

sdk/core/azure_core/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88

99
### Breaking Changes
1010

11+
- Removed `AccessToken::is_expired()`
12+
1113
### Bugs Fixed
1214

15+
- `BearerTokenCredentialPolicy` returns an error when a proactive token refresh attempt fails
16+
1317
### Other Changes
1418

1519
## 0.24.0 (2025-05-02)

sdk/core/azure_core/src/credentials.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
//! Azure authentication and authorization.
55
66
use serde::{Deserialize, Serialize};
7-
use std::{borrow::Cow, fmt::Debug, time::Duration};
7+
use std::{borrow::Cow, fmt::Debug};
88
use typespec_client_core::date::OffsetDateTime;
99

1010
/// Default Azure authorization scope.
@@ -85,13 +85,6 @@ impl AccessToken {
8585
expires_on,
8686
}
8787
}
88-
89-
/// Check if the token is expired within a given duration.
90-
///
91-
/// If no duration is provided, then the default duration of 30 seconds is used.
92-
pub fn is_expired(&self, window: Option<Duration>) -> bool {
93-
self.expires_on < OffsetDateTime::now_utc() + window.unwrap_or(Duration::from_secs(30))
94-
}
9588
}
9689

9790
/// Represents a credential capable of providing an OAuth token.

sdk/core/azure_core/src/http/policies/bearer_token_policy.rs

Lines changed: 194 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use async_lock::RwLock;
1313
use async_trait::async_trait;
1414
use std::sync::Arc;
1515
use std::time::Duration;
16+
use typespec_client_core::date::OffsetDateTime;
1617
use typespec_client_core::http::{Context, Request};
1718

1819
/// Authentication policy for a bearer token.
@@ -23,9 +24,6 @@ pub struct BearerTokenCredentialPolicy {
2324
access_token: Arc<RwLock<Option<AccessToken>>>,
2425
}
2526

26-
/// Default timeout in seconds before refreshing a new token.
27-
const DEFAULT_REFRESH_TIME: Duration = Duration::from_secs(120);
28-
2927
impl BearerTokenCredentialPolicy {
3028
pub fn new<A, B>(credential: Arc<dyn TokenCredential>, scopes: A) -> Self
3129
where
@@ -63,16 +61,44 @@ impl Policy for BearerTokenCredentialPolicy {
6361
) -> PolicyResult {
6462
let access_token = self.access_token.read().await;
6563

66-
if let Some(token) = &(*access_token) {
67-
if token.is_expired(Some(DEFAULT_REFRESH_TIME)) {
64+
match access_token.as_ref() {
65+
None => {
66+
// cache is empty. Upgrade the lock and acquire a token, provided another thread hasn't already done so
67+
drop(access_token);
68+
let mut access_token = self.access_token.write().await;
69+
if access_token.is_none() {
70+
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
71+
}
72+
}
73+
Some(token) if should_refresh(&token.expires_on) => {
74+
// token is expired or within its refresh window. Upgrade the lock and
75+
// acquire a new token, provided another thread hasn't already done so
76+
let expires_on = token.expires_on;
6877
drop(access_token);
6978
let mut access_token = self.access_token.write().await;
70-
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
79+
// access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic
80+
if access_token.is_none() || access_token.as_ref().unwrap().expires_on == expires_on
81+
{
82+
match self.credential.get_token(&self.scopes()).await {
83+
Ok(new_token) => {
84+
*access_token = Some(new_token);
85+
}
86+
Err(e)
87+
if access_token.is_none()
88+
|| expires_on <= OffsetDateTime::now_utc() =>
89+
{
90+
// propagate this error because we can't proceed without a new token
91+
return Err(e);
92+
}
93+
Err(_) => {
94+
// ignore this error because the cached token is still valid
95+
}
96+
}
97+
}
98+
}
99+
Some(_) => {
100+
// do nothing; cached token is valid and not within its refresh window
71101
}
72-
} else {
73-
drop(access_token);
74-
let mut access_token = self.access_token.write().await;
75-
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
76102
}
77103

78104
let access_token = self.access_token().await.ok_or_else(|| {
@@ -86,3 +112,161 @@ impl Policy for BearerTokenCredentialPolicy {
86112
next[0].send(ctx, request, &next[1..]).await
87113
}
88114
}
115+
116+
fn should_refresh(expires_on: &OffsetDateTime) -> bool {
117+
*expires_on <= OffsetDateTime::now_utc() + Duration::from_secs(300)
118+
}
119+
120+
#[cfg(test)]
121+
mod tests {
122+
use super::*;
123+
use crate::{
124+
credentials::{Secret, TokenCredential},
125+
http::{
126+
headers::{Headers, AUTHORIZATION},
127+
policies::Policy,
128+
Request, Response, StatusCode,
129+
},
130+
Bytes, Result,
131+
};
132+
use async_trait::async_trait;
133+
use azure_core_test::http::MockHttpClient;
134+
use futures::FutureExt;
135+
use std::sync::{
136+
atomic::{AtomicUsize, Ordering},
137+
Arc,
138+
};
139+
use std::time::Duration;
140+
use time::OffsetDateTime;
141+
use typespec_client_core::http::{policies::TransportPolicy, Method, TransportOptions};
142+
143+
#[derive(Debug, Clone)]
144+
struct MockCredential {
145+
calls: Arc<AtomicUsize>,
146+
tokens: Arc<[AccessToken]>,
147+
}
148+
149+
impl MockCredential {
150+
fn new(tokens: &[AccessToken]) -> Self {
151+
Self {
152+
calls: Arc::new(AtomicUsize::new(0)),
153+
tokens: tokens.into(),
154+
}
155+
}
156+
157+
fn get_token_calls(&self) -> usize {
158+
self.calls.load(Ordering::SeqCst)
159+
}
160+
}
161+
162+
// ensure the number of get_token() calls matches the number of tokens
163+
// in a test case i.e., that the policy called get_token() as expected
164+
impl Drop for MockCredential {
165+
fn drop(&mut self) {
166+
if !self.tokens.is_empty() {
167+
assert_eq!(self.tokens.len(), self.calls.load(Ordering::SeqCst));
168+
}
169+
}
170+
}
171+
172+
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
173+
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
174+
impl TokenCredential for MockCredential {
175+
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
176+
let i = self.calls.fetch_add(1, Ordering::SeqCst);
177+
self.tokens
178+
.get(i)
179+
.ok_or_else(|| Error::message(ErrorKind::Credential, "no more mock tokens"))
180+
.cloned()
181+
}
182+
}
183+
184+
#[tokio::test]
185+
async fn authn_error() {
186+
// this mock's get_token() will return an error because it has no tokens
187+
let credential = MockCredential::new(&[]);
188+
let policy = BearerTokenCredentialPolicy::new(Arc::new(credential), ["scope"]);
189+
let client = MockHttpClient::new(|_| panic!("expected an error from get_token"));
190+
let transport = Arc::new(TransportPolicy::new(TransportOptions::new(Arc::new(
191+
client,
192+
))));
193+
let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get);
194+
195+
let err = policy
196+
.send(&Context::default(), &mut req, &[transport.clone()])
197+
.await
198+
.expect_err("request should fail");
199+
200+
assert_eq!(ErrorKind::Credential, *err.kind());
201+
}
202+
203+
async fn run_test(tokens: &[AccessToken]) {
204+
let credential = Arc::new(MockCredential::new(tokens));
205+
let policy = BearerTokenCredentialPolicy::new(credential.clone(), ["scope"]);
206+
let client = Arc::new(MockHttpClient::new(move |actual| {
207+
let credential = credential.clone();
208+
async move {
209+
let authz = actual.headers().get_str(&AUTHORIZATION)?;
210+
// e.g. if this is the first request, we expect 1 get_token call and tokens[0] in the header
211+
let i = credential.get_token_calls().saturating_sub(1);
212+
let expected = &credential.tokens[i];
213+
214+
assert_eq!(format!("Bearer {}", expected.token.secret()), authz);
215+
216+
Ok(Response::from_bytes(
217+
StatusCode::Ok,
218+
Headers::new(),
219+
Bytes::new(),
220+
))
221+
}
222+
.boxed()
223+
}));
224+
let transport = Arc::new(TransportPolicy::new(TransportOptions::new(client)));
225+
226+
let mut handles = vec![];
227+
for _ in 0..4 {
228+
let policy = policy.clone();
229+
let transport = transport.clone();
230+
let handle = tokio::spawn(async move {
231+
let ctx = Context::default();
232+
let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get);
233+
policy
234+
.send(&ctx, &mut req, &[transport.clone()])
235+
.await
236+
.expect("successful request");
237+
});
238+
handles.push(handle);
239+
}
240+
241+
for handle in handles {
242+
tokio::time::timeout(Duration::from_secs(2), handle)
243+
.await
244+
.expect("task timed out after 2 seconds")
245+
.expect("completed task");
246+
}
247+
}
248+
249+
#[tokio::test]
250+
async fn caches_token() {
251+
run_test(&[AccessToken {
252+
token: Secret::new("fake".to_string()),
253+
expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600),
254+
}])
255+
.await;
256+
}
257+
258+
#[tokio::test]
259+
async fn refreshes_token() {
260+
run_test(&[
261+
AccessToken {
262+
token: Secret::new("1".to_string()),
263+
expires_on: OffsetDateTime::now_utc() - Duration::from_secs(1),
264+
},
265+
AccessToken {
266+
token: Secret::new("2".to_string()),
267+
expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600),
268+
},
269+
])
270+
.await;
271+
}
272+
}

sdk/identity/azure_identity/src/credentials/cache.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use async_lock::RwLock;
55
use azure_core::credentials::AccessToken;
66
use futures::Future;
77
use std::collections::HashMap;
8+
use std::time::Duration;
89
use tracing::trace;
10+
use typespec_client_core::date::OffsetDateTime;
911

1012
#[derive(Debug)]
1113
pub(crate) struct TokenCache(RwLock<HashMap<Vec<String>, AccessToken>>);
@@ -24,7 +26,7 @@ impl TokenCache {
2426
let token_cache = self.0.read().await;
2527
let scopes = scopes.iter().map(ToString::to_string).collect::<Vec<_>>();
2628
if let Some(token) = token_cache.get(&scopes) {
27-
if !token.is_expired(None) {
29+
if !should_refresh(token) {
2830
trace!("returning cached token");
2931
return Ok(token.clone());
3032
}
@@ -37,7 +39,7 @@ impl TokenCache {
3739
// check again in case another thread refreshed the token while we were
3840
// waiting on the write lock
3941
if let Some(token) = token_cache.get(&scopes) {
40-
if !token.is_expired(None) {
42+
if !should_refresh(token) {
4143
trace!("returning token that was updated while waiting on write lock");
4244
return Ok(token.clone());
4345
}
@@ -61,6 +63,10 @@ impl Default for TokenCache {
6163
}
6264
}
6365

66+
fn should_refresh(token: &AccessToken) -> bool {
67+
token.expires_on <= OffsetDateTime::now_utc() + Duration::from_secs(300)
68+
}
69+
6470
#[cfg(test)]
6571
mod tests {
6672
use super::*;
@@ -106,7 +112,7 @@ mod tests {
106112
let resource1 = &[STORAGE_TOKEN_SCOPE];
107113
let resource2 = &[IOTHUB_TOKEN_SCOPE];
108114
let secret_string = "test-token";
109-
let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(300);
115+
let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(3600);
110116
let access_token = AccessToken::new(Secret::new(secret_string), expires_on);
111117

112118
let mock_credential = MockCredential::new(access_token);

0 commit comments

Comments
 (0)