-
Notifications
You must be signed in to change notification settings - Fork 283
Remove AccessToken::is_expired() #2611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ use async_lock::RwLock; | |
use async_trait::async_trait; | ||
use std::sync::Arc; | ||
use std::time::Duration; | ||
use typespec_client_core::date::OffsetDateTime; | ||
use typespec_client_core::http::{Context, Request}; | ||
|
||
/// Authentication policy for a bearer token. | ||
|
@@ -23,9 +24,6 @@ pub struct BearerTokenCredentialPolicy { | |
access_token: Arc<RwLock<Option<AccessToken>>>, | ||
} | ||
|
||
/// Default timeout in seconds before refreshing a new token. | ||
const DEFAULT_REFRESH_TIME: Duration = Duration::from_secs(120); | ||
|
||
impl BearerTokenCredentialPolicy { | ||
pub fn new<A, B>(credential: Arc<dyn TokenCredential>, scopes: A) -> Self | ||
where | ||
|
@@ -63,16 +61,44 @@ impl Policy for BearerTokenCredentialPolicy { | |
) -> PolicyResult { | ||
let access_token = self.access_token.read().await; | ||
|
||
if let Some(token) = &(*access_token) { | ||
if token.is_expired(Some(DEFAULT_REFRESH_TIME)) { | ||
match access_token.as_ref() { | ||
None => { | ||
// cache is empty. Upgrade the lock and acquire a token, provided another thread hasn't already done so | ||
drop(access_token); | ||
let mut access_token = self.access_token.write().await; | ||
if access_token.is_none() { | ||
*access_token = Some(self.credential.get_token(&self.scopes()).await?); | ||
} | ||
} | ||
Some(token) if should_refresh(&token.expires_on) => { | ||
// token is expired or within its refresh window. Upgrade the lock and | ||
// acquire a new token, provided another thread hasn't already done so | ||
let expires_on = token.expires_on; | ||
drop(access_token); | ||
let mut access_token = self.access_token.write().await; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We probably don't need to take a write lock this soon - only when we write. Is there any disadvantage to a couple of threads potentially authenticating but only the last writer wins? IIRC, this is how creds in the .NET SDK work, and generally advised if a guarded operation is non-destructive and, in this case, wouldn't get us throttled (I doubt that for auth - not within reason, anyway). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Throttling is actually a real danger because IMDS allows only 5 requests per second. We may need to add backoff between proactive refresh attempts to reduce the risk further; a highly concurrent app could hammer IMDS during an outage and get throttled. |
||
*access_token = Some(self.credential.get_token(&self.scopes()).await?); | ||
// access_token shouldn't be None here, but check anyway to guarantee unwrap won't panic | ||
if access_token.is_none() || access_token.as_ref().unwrap().expires_on == expires_on | ||
{ | ||
match self.credential.get_token(&self.scopes()).await { | ||
chlowell marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Ok(new_token) => { | ||
*access_token = Some(new_token); | ||
} | ||
Err(e) | ||
if access_token.is_none() | ||
|| expires_on <= OffsetDateTime::now_utc() => | ||
{ | ||
// propagate this error because we can't proceed without a new token | ||
return Err(e); | ||
} | ||
Err(_) => { | ||
// ignore this error because the cached token is still valid | ||
} | ||
} | ||
} | ||
} | ||
Some(_) => { | ||
// do nothing; cached token is valid and not within its refresh window | ||
} | ||
} else { | ||
drop(access_token); | ||
let mut access_token = self.access_token.write().await; | ||
*access_token = Some(self.credential.get_token(&self.scopes()).await?); | ||
} | ||
|
||
let access_token = self.access_token().await.ok_or_else(|| { | ||
|
@@ -86,3 +112,161 @@ impl Policy for BearerTokenCredentialPolicy { | |
next[0].send(ctx, request, &next[1..]).await | ||
} | ||
} | ||
|
||
fn should_refresh(expires_on: &OffsetDateTime) -> bool { | ||
*expires_on <= OffsetDateTime::now_utc() + Duration::from_secs(300) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use crate::{ | ||
credentials::{Secret, TokenCredential}, | ||
http::{ | ||
headers::{Headers, AUTHORIZATION}, | ||
policies::Policy, | ||
Request, Response, StatusCode, | ||
}, | ||
Bytes, Result, | ||
}; | ||
use async_trait::async_trait; | ||
use azure_core_test::http::MockHttpClient; | ||
use futures::FutureExt; | ||
use std::sync::{ | ||
atomic::{AtomicUsize, Ordering}, | ||
Arc, | ||
}; | ||
use std::time::Duration; | ||
use time::OffsetDateTime; | ||
use typespec_client_core::http::{policies::TransportPolicy, Method, TransportOptions}; | ||
|
||
#[derive(Debug, Clone)] | ||
struct MockCredential { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just FYI: we actually have one in |
||
calls: Arc<AtomicUsize>, | ||
tokens: Arc<[AccessToken]>, | ||
} | ||
|
||
impl MockCredential { | ||
fn new(tokens: &[AccessToken]) -> Self { | ||
Self { | ||
calls: Arc::new(AtomicUsize::new(0)), | ||
tokens: tokens.into(), | ||
} | ||
} | ||
|
||
fn get_token_calls(&self) -> usize { | ||
self.calls.load(Ordering::SeqCst) | ||
} | ||
} | ||
|
||
// ensure the number of get_token() calls matches the number of tokens | ||
// in a test case i.e., that the policy called get_token() as expected | ||
impl Drop for MockCredential { | ||
fn drop(&mut self) { | ||
if !self.tokens.is_empty() { | ||
assert_eq!(self.tokens.len(), self.calls.load(Ordering::SeqCst)); | ||
} | ||
} | ||
} | ||
|
||
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] | ||
#[cfg_attr(not(target_arch = "wasm32"), async_trait)] | ||
impl TokenCredential for MockCredential { | ||
async fn get_token(&self, _scopes: &[&str]) -> Result<AccessToken> { | ||
let i = self.calls.fetch_add(1, Ordering::SeqCst); | ||
self.tokens | ||
.get(i) | ||
.ok_or_else(|| Error::message(ErrorKind::Credential, "no more mock tokens")) | ||
.cloned() | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn authn_error() { | ||
// this mock's get_token() will return an error because it has no tokens | ||
let credential = MockCredential::new(&[]); | ||
let policy = BearerTokenCredentialPolicy::new(Arc::new(credential), ["scope"]); | ||
let client = MockHttpClient::new(|_| panic!("expected an error from get_token")); | ||
let transport = Arc::new(TransportPolicy::new(TransportOptions::new(Arc::new( | ||
client, | ||
)))); | ||
let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); | ||
|
||
let err = policy | ||
.send(&Context::default(), &mut req, &[transport.clone()]) | ||
.await | ||
.expect_err("request should fail"); | ||
|
||
assert_eq!(ErrorKind::Credential, *err.kind()); | ||
} | ||
|
||
async fn run_test(tokens: &[AccessToken]) { | ||
let credential = Arc::new(MockCredential::new(tokens)); | ||
let policy = BearerTokenCredentialPolicy::new(credential.clone(), ["scope"]); | ||
let client = Arc::new(MockHttpClient::new(move |actual| { | ||
let credential = credential.clone(); | ||
async move { | ||
let authz = actual.headers().get_str(&AUTHORIZATION)?; | ||
// e.g. if this is the first request, we expect 1 get_token call and tokens[0] in the header | ||
let i = credential.get_token_calls().saturating_sub(1); | ||
let expected = &credential.tokens[i]; | ||
|
||
assert_eq!(format!("Bearer {}", expected.token.secret()), authz); | ||
|
||
Ok(Response::from_bytes( | ||
StatusCode::Ok, | ||
Headers::new(), | ||
Bytes::new(), | ||
)) | ||
} | ||
.boxed() | ||
})); | ||
let transport = Arc::new(TransportPolicy::new(TransportOptions::new(client))); | ||
|
||
let mut handles = vec![]; | ||
for _ in 0..4 { | ||
let policy = policy.clone(); | ||
let transport = transport.clone(); | ||
let handle = tokio::spawn(async move { | ||
let ctx = Context::default(); | ||
let mut req = Request::new("https://localhost".parse().unwrap(), Method::Get); | ||
policy | ||
.send(&ctx, &mut req, &[transport.clone()]) | ||
.await | ||
.expect("successful request"); | ||
}); | ||
handles.push(handle); | ||
} | ||
|
||
for handle in handles { | ||
tokio::time::timeout(Duration::from_secs(2), handle) | ||
.await | ||
.expect("task timed out after 2 seconds") | ||
.expect("completed task"); | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn caches_token() { | ||
run_test(&[AccessToken { | ||
token: Secret::new("fake".to_string()), | ||
expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600), | ||
}]) | ||
.await; | ||
} | ||
|
||
#[tokio::test] | ||
async fn refreshes_token() { | ||
run_test(&[ | ||
AccessToken { | ||
token: Secret::new("1".to_string()), | ||
expires_on: OffsetDateTime::now_utc() - Duration::from_secs(1), | ||
}, | ||
AccessToken { | ||
token: Secret::new("2".to_string()), | ||
expires_on: OffsetDateTime::now_utc() + Duration::from_secs(3600), | ||
}, | ||
]) | ||
.await; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Especially if this was a customer-filed bug, we reference the bug number such that it renders as
(#{num})
. If not a customer-filed issue, no need.