Skip to content

Commit 98638b1

Browse files
authored
Remove federated_credentials_flow (#2500)
1 parent 20181be commit 98638b1

File tree

5 files changed

+208
-184
lines changed

5 files changed

+208
-184
lines changed

sdk/identity/azure_identity/src/client_secret_credential.rs

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
use crate::{credentials::cache::TokenCache, EntraIdTokenResponse};
5-
use crate::{EntraIdErrorResponse, TokenCredentialOptions};
6-
use azure_core::http::{Response, StatusCode};
4+
use crate::{
5+
credentials::cache::TokenCache, deserialize, EntraIdErrorResponse, EntraIdTokenResponse,
6+
TokenCredentialOptions,
7+
};
8+
use azure_core::http::StatusCode;
79
use azure_core::Result;
810
use azure_core::{
911
credentials::{AccessToken, Secret, TokenCredential},
@@ -19,6 +21,8 @@ use std::{str, sync::Arc};
1921
use time::OffsetDateTime;
2022
use url::form_urlencoded;
2123

24+
const CLIENT_SECRET_CREDENTIAL: &str = "ClientSecretCredential";
25+
2226
/// Options for constructing a new [`ClientSecretCredential`].
2327
#[derive(Debug, Default)]
2428
pub struct ClientSecretCredentialOptions {
@@ -83,18 +87,24 @@ impl ClientSecretCredential {
8387

8488
match res.status() {
8589
StatusCode::Ok => {
86-
let token_response: EntraIdTokenResponse = deserialize(res).await?;
90+
let token_response: EntraIdTokenResponse =
91+
deserialize(CLIENT_SECRET_CREDENTIAL, res).await?;
8792
Ok(AccessToken::new(
8893
token_response.access_token,
8994
OffsetDateTime::now_utc() + Duration::from_secs(token_response.expires_in),
9095
))
9196
}
9297
_ => {
93-
let error_response: EntraIdErrorResponse = deserialize(res).await?;
94-
let mut message = "ClientSecretCredential authentication failed".to_string();
95-
if !error_response.error_description.is_empty() {
96-
message = format!("{}: {}", message, error_response.error_description);
97-
}
98+
let error_response: EntraIdErrorResponse =
99+
deserialize(CLIENT_SECRET_CREDENTIAL, res).await?;
100+
let message = if error_response.error_description.is_empty() {
101+
format!("{} authentication failed.", CLIENT_SECRET_CREDENTIAL)
102+
} else {
103+
format!(
104+
"{} authentication failed. {}",
105+
CLIENT_SECRET_CREDENTIAL, error_response.error_description
106+
)
107+
};
98108
Err(Error::message(ErrorKind::Credential, message))
99109
}
100110
}
@@ -114,19 +124,6 @@ impl TokenCredential for ClientSecretCredential {
114124
}
115125
}
116126

117-
async fn deserialize<T>(res: Response) -> Result<T>
118-
where
119-
T: serde::de::DeserializeOwned,
120-
{
121-
let t: T = res
122-
.into_json_body()
123-
.await
124-
.with_context(ErrorKind::Credential, || {
125-
"ClientSecretCredential authentication failed: invalid response"
126-
})?;
127-
Ok(t)
128-
}
129-
130127
#[cfg(test)]
131128
mod tests {
132129
use super::*;

sdk/identity/azure_identity/src/credentials/client_assertion_credentials.rs

Lines changed: 168 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
use crate::{credentials::cache::TokenCache, federated_credentials_flow, TokenCredentialOptions};
4+
use crate::{
5+
credentials::cache::TokenCache, deserialize, validate_not_empty, validate_tenant_id,
6+
EntraIdErrorResponse, EntraIdTokenResponse, TokenCredentialOptions,
7+
};
58
use azure_core::{
69
credentials::{AccessToken, TokenCredential},
710
error::{ErrorKind, ResultExt},
11+
http::{
12+
headers::{self, content_type},
13+
Method, Request, StatusCode, Url,
14+
},
15+
Error,
816
};
917
use std::{fmt::Debug, str, sync::Arc, time::Duration};
1018
use time::OffsetDateTime;
19+
use url::form_urlencoded;
20+
21+
const ASSERTION_TYPE: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
22+
const CLIENT_ASSERTION_CREDENTIAL: &str = "ClientAssertionCredential";
1123

1224
/// Enables authentication of a Microsoft Entra service principal using a signed client assertion.
1325
#[derive(Debug)]
1426
pub struct ClientAssertionCredential<C> {
15-
tenant_id: String,
1627
client_id: String,
28+
endpoint: Url,
1729
assertion: C,
1830
cache: TokenCache,
19-
options: ClientAssertionCredentialOptions,
31+
options: TokenCredentialOptions,
2032
}
2133

2234
/// Options for constructing a new [`ClientAssertionCredential`].
@@ -68,59 +80,92 @@ impl<C: ClientAssertion> ClientAssertionCredential<C> {
6880
assertion: C,
6981
options: Option<ClientAssertionCredentialOptions>,
7082
) -> azure_core::Result<Self> {
83+
validate_tenant_id(&tenant_id)?;
84+
validate_not_empty(&client_id, "no client ID specified")?;
85+
let options = options.unwrap_or_default().credential_options;
86+
let endpoint = options
87+
.authority_host()?
88+
.join(&format!("/{tenant_id}/oauth2/v2.0/token"))
89+
.with_context(ErrorKind::DataConversion, || {
90+
format!("tenant_id {tenant_id} could not be URL encoded")
91+
})?;
7192
Ok(Self {
72-
tenant_id,
7393
client_id,
7494
assertion,
95+
endpoint,
7596
cache: TokenCache::new(),
76-
options: options.unwrap_or_default(),
97+
options,
7798
})
7899
}
79100

80-
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
81-
let token = self.assertion.secret().await?;
82-
let credential_options = &self.options.credential_options;
83-
let res: AccessToken = federated_credentials_flow::authorize(
84-
credential_options.http_client().clone(),
85-
&self.client_id,
86-
&token,
87-
scopes,
88-
&self.tenant_id,
89-
&credential_options.authority_host()?,
90-
)
91-
.await
92-
.map(|r| {
93-
AccessToken::new(
94-
r.access_token().clone(),
95-
OffsetDateTime::now_utc() + Duration::from_secs(r.expires_in),
96-
)
97-
})
98-
.context(ErrorKind::Credential, "request token error")?;
99-
Ok(res)
101+
async fn get_token_impl(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
102+
let mut req = Request::new(self.endpoint.clone(), Method::Post);
103+
req.insert_header(
104+
headers::CONTENT_TYPE,
105+
content_type::APPLICATION_X_WWW_FORM_URLENCODED,
106+
);
107+
let assertion = self.assertion.secret().await?;
108+
let encoded: String = form_urlencoded::Serializer::new(String::new())
109+
.append_pair("client_assertion", assertion.as_str())
110+
.append_pair("client_assertion_type", ASSERTION_TYPE)
111+
.append_pair("client_id", self.client_id.as_str())
112+
.append_pair("grant_type", "client_credentials")
113+
.append_pair("scope", &scopes.join(" "))
114+
.finish();
115+
req.set_body(encoded);
116+
117+
let res = self.options.http_client.execute_request(&req).await?;
118+
119+
match res.status() {
120+
StatusCode::Ok => {
121+
let token_response: EntraIdTokenResponse =
122+
deserialize(CLIENT_ASSERTION_CREDENTIAL, res).await?;
123+
Ok(AccessToken::new(
124+
token_response.access_token,
125+
OffsetDateTime::now_utc() + Duration::from_secs(token_response.expires_in),
126+
))
127+
}
128+
_ => {
129+
let error_response: EntraIdErrorResponse =
130+
deserialize(CLIENT_ASSERTION_CREDENTIAL, res).await?;
131+
let message = if error_response.error_description.is_empty() {
132+
format!("{} authentication failed.", CLIENT_ASSERTION_CREDENTIAL)
133+
} else {
134+
format!(
135+
"{} authentication failed. {}",
136+
CLIENT_ASSERTION_CREDENTIAL, error_response.error_description
137+
)
138+
};
139+
Err(Error::message(ErrorKind::Credential, message))
140+
}
141+
}
100142
}
101143
}
102144

103145
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
104146
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
105147
impl<C: ClientAssertion> TokenCredential for ClientAssertionCredential<C> {
106148
async fn get_token(&self, scopes: &[&str]) -> azure_core::Result<AccessToken> {
107-
self.cache.get_token(scopes, self.get_token(scopes)).await
149+
self.cache
150+
.get_token(scopes, self.get_token_impl(scopes))
151+
.await
108152
}
109153
}
110154

111155
#[cfg(test)]
112156
pub(crate) mod tests {
113-
use std::collections::HashMap;
114-
115157
use super::*;
116158
use crate::tests::*;
117159
use azure_core::{
118160
authority_hosts::AZURE_PUBLIC_CLOUD,
119161
http::{
120-
headers::{self, content_type},
121-
Body, Method, Request,
162+
headers::{self, content_type, Headers},
163+
Body, Method, Request, Response,
122164
},
165+
Bytes,
123166
};
167+
use std::{collections::HashMap, time::SystemTime};
168+
use time::UtcOffset;
124169
use url::form_urlencoded;
125170

126171
pub const FAKE_ASSERTION: &str = "fake assertion";
@@ -140,10 +185,7 @@ pub(crate) mod tests {
140185
);
141186
let expected_params = [
142187
("client_assertion", FAKE_ASSERTION),
143-
(
144-
"client_assertion_type",
145-
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
146-
),
188+
("client_assertion_type", ASSERTION_TYPE),
147189
("client_id", FAKE_CLIENT_ID),
148190
("grant_type", "client_credentials"),
149191
("scope", &LIVE_TEST_SCOPES.join(" ")),
@@ -166,4 +208,96 @@ pub(crate) mod tests {
166208
Ok(())
167209
}
168210
}
211+
212+
#[derive(Debug)]
213+
struct MockAssertion {}
214+
215+
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
216+
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
217+
impl ClientAssertion for MockAssertion {
218+
async fn secret(&self) -> azure_core::Result<String> {
219+
Ok(FAKE_ASSERTION.to_string())
220+
}
221+
}
222+
223+
#[tokio::test]
224+
async fn get_token_error() {
225+
let expected = "error description from the response";
226+
let mock = MockSts::new(
227+
vec![Response::from_bytes(
228+
StatusCode::BadRequest,
229+
Headers::default(),
230+
Bytes::from(format!(
231+
r#"{{"error":"invalid_request","error_description":"{}","error_codes":[50027],"timestamp":"2025-04-18 16:04:37Z","trace_id":"...","correlation_id":"...","error_uri":"https://login.microsoftonline.com/error?code=50027"}}"#,
232+
expected
233+
)),
234+
)],
235+
Some(Arc::new(is_valid_request())),
236+
);
237+
let credential = ClientAssertionCredential::new(
238+
FAKE_TENANT_ID.to_string(),
239+
FAKE_CLIENT_ID.to_string(),
240+
MockAssertion {},
241+
Some(ClientAssertionCredentialOptions {
242+
credential_options: TokenCredentialOptions {
243+
http_client: Arc::new(mock),
244+
..Default::default()
245+
},
246+
..Default::default()
247+
}),
248+
)
249+
.expect("valid credential");
250+
251+
let error = credential
252+
.get_token(LIVE_TEST_SCOPES)
253+
.await
254+
.expect_err("authentication error");
255+
assert!(matches!(error.kind(), ErrorKind::Credential));
256+
assert!(
257+
error.to_string().contains(expected),
258+
"expected error description from the response, got '{}'",
259+
error
260+
);
261+
}
262+
263+
#[tokio::test]
264+
async fn get_token_success() {
265+
let mock = MockSts::new(
266+
vec![Response::from_bytes(
267+
StatusCode::Ok,
268+
Headers::default(),
269+
Bytes::from(format!(
270+
r#"{{"access_token":"{}","expires_in":3600,"token_type":"Bearer"}}"#,
271+
FAKE_TOKEN
272+
)),
273+
)],
274+
Some(Arc::new(is_valid_request())),
275+
);
276+
let credential = ClientAssertionCredential::new(
277+
FAKE_TENANT_ID.to_string(),
278+
FAKE_CLIENT_ID.to_string(),
279+
MockAssertion {},
280+
Some(ClientAssertionCredentialOptions {
281+
credential_options: TokenCredentialOptions {
282+
http_client: Arc::new(mock),
283+
..Default::default()
284+
},
285+
..Default::default()
286+
}),
287+
)
288+
.expect("valid credential");
289+
290+
let token = credential.get_token(LIVE_TEST_SCOPES).await.expect("token");
291+
assert_eq!(FAKE_TOKEN, token.token.secret());
292+
assert!(token.expires_on > SystemTime::now());
293+
assert_eq!(UtcOffset::UTC, token.expires_on.offset());
294+
295+
// MockSts will return an error if the credential sends another request
296+
let cached_token = credential
297+
.get_token(LIVE_TEST_SCOPES)
298+
.await
299+
.expect("cached token");
300+
assert_eq!(token.token.secret(), cached_token.token.secret());
301+
assert_eq!(token.expires_on, cached_token.expires_on);
302+
}
169303
}

0 commit comments

Comments
 (0)