Skip to content

Remove AccessToken::is_expired() #2611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/core/azure_core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

### Breaking Changes

- Removed `AccessToken::is_expired()`

### Bugs Fixed

- `BearerTokenCredentialPolicy` returns an error when a proactive token refresh attempt fails
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Especially if this was a customer-filed bug, we reference the issue number so that it renders as (#{num}). If it was not a customer-filed issue, there is no need.


### Other Changes

## 0.24.0 (2025-05-02)
Expand Down
9 changes: 1 addition & 8 deletions sdk/core/azure_core/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! Azure authentication and authorization.

use serde::{Deserialize, Serialize};
use std::{borrow::Cow, fmt::Debug, time::Duration};
use std::{borrow::Cow, fmt::Debug};
use typespec_client_core::date::OffsetDateTime;

/// Default Azure authorization scope.
Expand Down Expand Up @@ -85,13 +85,6 @@ impl AccessToken {
expires_on,
}
}

/// Check if the token is expired within a given duration.
///
/// If no duration is provided, then the default duration of 30 seconds is used.
pub fn is_expired(&self, window: Option<Duration>) -> bool {
self.expires_on < OffsetDateTime::now_utc() + window.unwrap_or(Duration::from_secs(30))
}
}

/// Represents a credential capable of providing an OAuth token.
Expand Down
204 changes: 194 additions & 10 deletions sdk/core/azure_core/src/http/policies/bearer_token_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use async_lock::RwLock;
use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;
use typespec_client_core::date::OffsetDateTime;
use typespec_client_core::http::{Context, Request};

/// Authentication policy for a bearer token.
Expand All @@ -23,9 +24,6 @@ pub struct BearerTokenCredentialPolicy {
access_token: Arc<RwLock<Option<AccessToken>>>,
}

/// Default timeout in seconds before refreshing a new token.
const DEFAULT_REFRESH_TIME: Duration = Duration::from_secs(120);

impl BearerTokenCredentialPolicy {
pub fn new<A, B>(credential: Arc<dyn TokenCredential>, scopes: A) -> Self
where
Expand Down Expand Up @@ -63,16 +61,44 @@ impl Policy for BearerTokenCredentialPolicy {
) -> PolicyResult {
let access_token = self.access_token.read().await;

if let Some(token) = &(*access_token) {
if token.is_expired(Some(DEFAULT_REFRESH_TIME)) {
match access_token.as_ref() {
None => {
// cache is empty. Upgrade the lock and acquire a token, provided another thread hasn't already done so
drop(access_token);
let mut access_token = self.access_token.write().await;
if access_token.is_none() {
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
}
}
Some(token) if should_refresh(&token.expires_on) => {
// token is expired or within its refresh window. Upgrade the lock and
// acquire a new token, provided another thread hasn't already done so
let expires_on = token.expires_on;
drop(access_token);
let mut access_token = self.access_token.write().await;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need to take a write lock this soon - only when we write. Is there any disadvantage to a couple of threads potentially authenticating but only the last writer wins? IIRC, this is how creds in the .NET SDK work, and generally advised if a guarded operation is non-destructive and, in this case, wouldn't get us throttled (I doubt that for auth - not within reason, anyway).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throttling is actually a real danger because IMDS allows only 5 requests per second. We may need to add backoff between proactive refresh attempts to reduce the risk further; a highly concurrent app could hammer IMDS during an outage and get throttled.

*access_token = Some(self.credential.get_token(&self.scopes()).await?);
// access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic
if access_token.is_none() || access_token.as_ref().unwrap().expires_on == expires_on
{
match self.credential.get_token(&self.scopes()).await {
Ok(new_token) => {
*access_token = Some(new_token);
}
Err(e)
if access_token.is_none()
|| expires_on <= OffsetDateTime::now_utc() =>
{
// propagate this error because we can't proceed without a new token
return Err(e);
}
Err(_) => {
// ignore this error because the cached token is still valid
}
}
}
}
Some(_) => {
// do nothing; cached token is valid and not within its refresh window
}
} else {
drop(access_token);
let mut access_token = self.access_token.write().await;
*access_token = Some(self.credential.get_token(&self.scopes()).await?);
}

let access_token = self.access_token().await.ok_or_else(|| {
Expand All @@ -86,3 +112,161 @@ impl Policy for BearerTokenCredentialPolicy {
next[0].send(ctx, request, &next[1..]).await
}
}

/// Reports whether a token expiring at `expires_on` should be replaced.
///
/// Returns `true` when `expires_on` is at or before the current UTC time plus
/// 300 seconds, i.e. the token has already expired or expires within the next
/// five minutes.
fn should_refresh(expires_on: &OffsetDateTime) -> bool {
    const REFRESH_WINDOW: Duration = Duration::from_secs(300);

    let refresh_deadline = OffsetDateTime::now_utc() + REFRESH_WINDOW;
    refresh_deadline >= *expires_on
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
credentials::{Secret, TokenCredential},
http::{
headers::{Headers, AUTHORIZATION},
policies::Policy,
Request, Response, StatusCode,
},
Bytes, Result,
};
use async_trait::async_trait;
use azure_core_test::http::MockHttpClient;
use futures::FutureExt;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::Duration;
use time::OffsetDateTime;
use typespec_client_core::http::{policies::TransportPolicy, Method, TransportOptions};

#[derive(Debug, Clone)]
struct MockCredential {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI: we actually have one in azure_core_test, though maybe this one is purpose-built for this task.

calls: Arc<AtomicUsize>,
tokens: Arc<[AccessToken]>,
}

impl MockCredential {
    /// Creates a mock credential holding a copy of `tokens`, with its
    /// `get_token()` call counter starting at zero.
    fn new(tokens: &[AccessToken]) -> Self {
        let calls = Arc::new(AtomicUsize::new(0));
        let tokens: Arc<[AccessToken]> = Arc::from(tokens);
        Self { calls, tokens }
    }

    /// Returns the current value of the `get_token()` call counter.
    fn get_token_calls(&self) -> usize {
        let counter: &AtomicUsize = &self.calls;
        counter.load(Ordering::SeqCst)
    }
}

// Ensure the number of get_token() calls matches the number of tokens in a
// test case, i.e. that the policy called get_token() as expected. A mock
// created without any tokens is exempt from the check.
impl Drop for MockCredential {
    fn drop(&mut self) {
        // Skip the check while this thread is already unwinding from a panic:
        // a second panic raised inside drop() would abort the process and hide
        // the original test failure.
        if std::thread::panicking() || self.tokens.is_empty() {
            return;
        }
        assert_eq!(self.tokens.len(), self.calls.load(Ordering::SeqCst));
    }
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl TokenCredential for MockCredential {
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> {
let i = self.calls.fetch_add(1, Ordering::SeqCst);
self.tokens
.get(i)
.ok_or_else(|| Error::message(ErrorKind::Credential, "no more mock tokens"))
.cloned()
}
}

#[tokio::test]
async fn authn_error() {
// this mock's get_token() will return an error because it has no tokens
let credential = MockCredential::new(&[]);
let policy = BearerTokenCredentialPolicy::new(Arc::new(credential), ["scope"]);
let client = MockHttpClient::new(|_| panic!("expected an error from get_token"));
let transport = Arc::new(TransportPolicy::new(TransportOptions::new(Arc::new(
client,
))));
let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get);

let err = policy
.send(&Context::default(), &mut req, &[transport.clone()])
.await
.expect_err("request should fail");

assert_eq!(ErrorKind::Credential, *err.kind());
}

/// Sends four concurrent requests through a policy backed by a mock credential
/// serving `tokens`, asserting that each request's Authorization header carries
/// the token returned by the most recent `get_token()` call. Each spawned task
/// must finish within two seconds.
async fn run_test(tokens: &[AccessToken]) {
    let credential = Arc::new(MockCredential::new(tokens));
    let policy = BearerTokenCredentialPolicy::new(credential.clone(), ["scope"]);

    let client = Arc::new(MockHttpClient::new(move |actual| {
        let credential = credential.clone();
        async move {
            let authz = actual.headers().get_str(&AUTHORIZATION)?;
            // After N get_token() calls the header should hold tokens[N - 1];
            // e.g. on the first request we expect 1 call and tokens[0].
            let latest = credential.get_token_calls().saturating_sub(1);
            let expected_header =
                format!("Bearer {}", credential.tokens[latest].token.secret());

            assert_eq!(expected_header, authz);

            Ok(Response::from_bytes(
                StatusCode::Ok,
                Headers::new(),
                Bytes::new(),
            ))
        }
        .boxed()
    }));
    let transport = Arc::new(TransportPolicy::new(TransportOptions::new(client)));

    // Spawn every request up front so they contend for the policy's token cache.
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let policy = policy.clone();
            let transport = transport.clone();
            tokio::spawn(async move {
                let ctx = Context::default();
                let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get);
                policy
                    .send(&ctx, &mut req, &[transport.clone()])
                    .await
                    .expect("successful request");
            })
        })
        .collect();

    let per_task_limit = Duration::from_secs(2);
    for handle in handles {
        tokio::time::timeout(per_task_limit, handle)
            .await
            .expect("task timed out after 2 seconds")
            .expect("completed task");
    }
}

/// Runs the shared scenario with a single token that expires an hour from now.
#[tokio::test]
async fn caches_token() {
    let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(3600);
    let long_lived = AccessToken {
        token: Secret::new("fake".to_string()),
        expires_on,
    };

    run_test(&[long_lived]).await;
}

/// Runs the shared scenario with an already-expired token followed by one that
/// expires an hour from now.
#[tokio::test]
async fn refreshes_token() {
    let expired = AccessToken {
        token: Secret::new("1".to_string()),
        expires_on: OffsetDateTime::now_utc() - Duration::from_secs(1),
    };
    let fresh = AccessToken {
        token: Secret::new("2".to_string()),
        expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600),
    };

    run_test(&[expired, fresh]).await;
}
}
12 changes: 9 additions & 3 deletions sdk/identity/azure_identity/src/credentials/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use async_lock::RwLock;
use azure_core::credentials::AccessToken;
use futures::Future;
use std::collections::HashMap;
use std::time::Duration;
use tracing::trace;
use typespec_client_core::date::OffsetDateTime;

#[derive(Debug)]
pub(crate) struct TokenCache(RwLock<HashMap<Vec<String>, AccessToken>>);
Expand All @@ -24,7 +26,7 @@ impl TokenCache {
let token_cache = self.0.read().await;
let scopes = scopes.iter().map(ToString::to_string).collect::<Vec<_>>();
if let Some(token) = token_cache.get(&scopes) {
if !token.is_expired(None) {
if !should_refresh(token) {
trace!("returning cached token");
return Ok(token.clone());
}
Expand All @@ -37,7 +39,7 @@ impl TokenCache {
// check again in case another thread refreshed the token while we were
// waiting on the write lock
if let Some(token) = token_cache.get(&scopes) {
if !token.is_expired(None) {
if !should_refresh(token) {
trace!("returning token that was updated while waiting on write lock");
return Ok(token.clone());
}
Expand All @@ -61,6 +63,10 @@ impl Default for TokenCache {
}
}

/// Reports whether a cached `token` should be replaced.
///
/// Returns `true` when the token's `expires_on` is at or before the current
/// UTC time plus 300 seconds, i.e. it has expired or expires within the next
/// five minutes.
fn should_refresh(token: &AccessToken) -> bool {
    let refresh_window = Duration::from_secs(300);
    let now = OffsetDateTime::now_utc();
    now + refresh_window >= token.expires_on
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -106,7 +112,7 @@ mod tests {
let resource1 = &[STORAGE_TOKEN_SCOPE];
let resource2 = &[IOTHUB_TOKEN_SCOPE];
let secret_string = "test-token";
let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(300);
let expires_on = OffsetDateTime::now_utc() + Duration::from_secs(3600);
let access_token = AccessToken::new(Secret::new(secret_string), expires_on);

let mock_credential = MockCredential::new(access_token);
Expand Down