1
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
2
// Licensed under the MIT License.
3
3
4
- use crate :: { credentials:: cache:: TokenCache , federated_credentials_flow, TokenCredentialOptions } ;
4
+ use crate :: {
5
+ credentials:: cache:: TokenCache , deserialize, validate_not_empty, validate_tenant_id,
6
+ EntraIdErrorResponse , EntraIdTokenResponse , TokenCredentialOptions ,
7
+ } ;
5
8
use azure_core:: {
6
9
credentials:: { AccessToken , TokenCredential } ,
7
10
error:: { ErrorKind , ResultExt } ,
11
+ http:: {
12
+ headers:: { self , content_type} ,
13
+ Method , Request , StatusCode , Url ,
14
+ } ,
15
+ Error ,
8
16
} ;
9
17
use std:: { fmt:: Debug , str, sync:: Arc , time:: Duration } ;
10
18
use time:: OffsetDateTime ;
19
+ use url:: form_urlencoded;
20
+
21
+ const ASSERTION_TYPE : & str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ;
22
+ const CLIENT_ASSERTION_CREDENTIAL : & str = "ClientAssertionCredential" ;
11
23
12
24
/// Enables authentication of a Microsoft Entra service principal using a signed client assertion.
13
25
#[ derive( Debug ) ]
14
26
pub struct ClientAssertionCredential < C > {
15
- tenant_id : String ,
16
27
client_id : String ,
28
+ endpoint : Url ,
17
29
assertion : C ,
18
30
cache : TokenCache ,
19
- options : ClientAssertionCredentialOptions ,
31
+ options : TokenCredentialOptions ,
20
32
}
21
33
22
34
/// Options for constructing a new [`ClientAssertionCredential`].
@@ -68,59 +80,92 @@ impl<C: ClientAssertion> ClientAssertionCredential<C> {
68
80
assertion : C ,
69
81
options : Option < ClientAssertionCredentialOptions > ,
70
82
) -> azure_core:: Result < Self > {
83
+ validate_tenant_id ( & tenant_id) ?;
84
+ validate_not_empty ( & client_id, "no client ID specified" ) ?;
85
+ let options = options. unwrap_or_default ( ) . credential_options ;
86
+ let endpoint = options
87
+ . authority_host ( ) ?
88
+ . join ( & format ! ( "/{tenant_id}/oauth2/v2.0/token" ) )
89
+ . with_context ( ErrorKind :: DataConversion , || {
90
+ format ! ( "tenant_id {tenant_id} could not be URL encoded" )
91
+ } ) ?;
71
92
Ok ( Self {
72
- tenant_id,
73
93
client_id,
74
94
assertion,
95
+ endpoint,
75
96
cache : TokenCache :: new ( ) ,
76
- options : options . unwrap_or_default ( ) ,
97
+ options,
77
98
} )
78
99
}
79
100
80
- async fn get_token ( & self , scopes : & [ & str ] ) -> azure_core:: Result < AccessToken > {
81
- let token = self . assertion . secret ( ) . await ?;
82
- let credential_options = & self . options . credential_options ;
83
- let res: AccessToken = federated_credentials_flow:: authorize (
84
- credential_options. http_client ( ) . clone ( ) ,
85
- & self . client_id ,
86
- & token,
87
- scopes,
88
- & self . tenant_id ,
89
- & credential_options. authority_host ( ) ?,
90
- )
91
- . await
92
- . map ( |r| {
93
- AccessToken :: new (
94
- r. access_token ( ) . clone ( ) ,
95
- OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( r. expires_in ) ,
96
- )
97
- } )
98
- . context ( ErrorKind :: Credential , "request token error" ) ?;
99
- Ok ( res)
101
+ async fn get_token_impl ( & self , scopes : & [ & str ] ) -> azure_core:: Result < AccessToken > {
102
+ let mut req = Request :: new ( self . endpoint . clone ( ) , Method :: Post ) ;
103
+ req. insert_header (
104
+ headers:: CONTENT_TYPE ,
105
+ content_type:: APPLICATION_X_WWW_FORM_URLENCODED ,
106
+ ) ;
107
+ let assertion = self . assertion . secret ( ) . await ?;
108
+ let encoded: String = form_urlencoded:: Serializer :: new ( String :: new ( ) )
109
+ . append_pair ( "client_assertion" , assertion. as_str ( ) )
110
+ . append_pair ( "client_assertion_type" , ASSERTION_TYPE )
111
+ . append_pair ( "client_id" , self . client_id . as_str ( ) )
112
+ . append_pair ( "grant_type" , "client_credentials" )
113
+ . append_pair ( "scope" , & scopes. join ( " " ) )
114
+ . finish ( ) ;
115
+ req. set_body ( encoded) ;
116
+
117
+ let res = self . options . http_client . execute_request ( & req) . await ?;
118
+
119
+ match res. status ( ) {
120
+ StatusCode :: Ok => {
121
+ let token_response: EntraIdTokenResponse =
122
+ deserialize ( CLIENT_ASSERTION_CREDENTIAL , res) . await ?;
123
+ Ok ( AccessToken :: new (
124
+ token_response. access_token ,
125
+ OffsetDateTime :: now_utc ( ) + Duration :: from_secs ( token_response. expires_in ) ,
126
+ ) )
127
+ }
128
+ _ => {
129
+ let error_response: EntraIdErrorResponse =
130
+ deserialize ( CLIENT_ASSERTION_CREDENTIAL , res) . await ?;
131
+ let message = if error_response. error_description . is_empty ( ) {
132
+ format ! ( "{} authentication failed." , CLIENT_ASSERTION_CREDENTIAL )
133
+ } else {
134
+ format ! (
135
+ "{} authentication failed. {}" ,
136
+ CLIENT_ASSERTION_CREDENTIAL , error_response. error_description
137
+ )
138
+ } ;
139
+ Err ( Error :: message ( ErrorKind :: Credential , message) )
140
+ }
141
+ }
100
142
}
101
143
}
102
144
103
145
#[ cfg_attr( target_arch = "wasm32" , async_trait:: async_trait( ?Send ) ) ]
104
146
#[ cfg_attr( not( target_arch = "wasm32" ) , async_trait:: async_trait) ]
105
147
impl < C : ClientAssertion > TokenCredential for ClientAssertionCredential < C > {
106
148
async fn get_token ( & self , scopes : & [ & str ] ) -> azure_core:: Result < AccessToken > {
107
- self . cache . get_token ( scopes, self . get_token ( scopes) ) . await
149
+ self . cache
150
+ . get_token ( scopes, self . get_token_impl ( scopes) )
151
+ . await
108
152
}
109
153
}
110
154
111
155
#[ cfg( test) ]
112
156
pub ( crate ) mod tests {
113
- use std:: collections:: HashMap ;
114
-
115
157
use super :: * ;
116
158
use crate :: tests:: * ;
117
159
use azure_core:: {
118
160
authority_hosts:: AZURE_PUBLIC_CLOUD ,
119
161
http:: {
120
- headers:: { self , content_type} ,
121
- Body , Method , Request ,
162
+ headers:: { self , content_type, Headers } ,
163
+ Body , Method , Request , Response ,
122
164
} ,
165
+ Bytes ,
123
166
} ;
167
+ use std:: { collections:: HashMap , time:: SystemTime } ;
168
+ use time:: UtcOffset ;
124
169
use url:: form_urlencoded;
125
170
126
171
pub const FAKE_ASSERTION : & str = "fake assertion" ;
@@ -140,10 +185,7 @@ pub(crate) mod tests {
140
185
) ;
141
186
let expected_params = [
142
187
( "client_assertion" , FAKE_ASSERTION ) ,
143
- (
144
- "client_assertion_type" ,
145
- "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" ,
146
- ) ,
188
+ ( "client_assertion_type" , ASSERTION_TYPE ) ,
147
189
( "client_id" , FAKE_CLIENT_ID ) ,
148
190
( "grant_type" , "client_credentials" ) ,
149
191
( "scope" , & LIVE_TEST_SCOPES . join ( " " ) ) ,
@@ -166,4 +208,96 @@ pub(crate) mod tests {
166
208
Ok ( ( ) )
167
209
}
168
210
}
211
+
212
+ #[ derive( Debug ) ]
213
+ struct MockAssertion { }
214
+
215
+ #[ cfg_attr( target_arch = "wasm32" , async_trait:: async_trait( ?Send ) ) ]
216
+ #[ cfg_attr( not( target_arch = "wasm32" ) , async_trait:: async_trait) ]
217
+ impl ClientAssertion for MockAssertion {
218
+ async fn secret ( & self ) -> azure_core:: Result < String > {
219
+ Ok ( FAKE_ASSERTION . to_string ( ) )
220
+ }
221
+ }
222
+
223
+ #[ tokio:: test]
224
+ async fn get_token_error ( ) {
225
+ let expected = "error description from the response" ;
226
+ let mock = MockSts :: new (
227
+ vec ! [ Response :: from_bytes(
228
+ StatusCode :: BadRequest ,
229
+ Headers :: default ( ) ,
230
+ Bytes :: from( format!(
231
+ r#"{{"error":"invalid_request","error_description":"{}","error_codes":[50027],"timestamp":"2025-04-18 16:04:37Z","trace_id":"...","correlation_id":"...","error_uri":"https://login.microsoftonline.com/error?code=50027"}}"# ,
232
+ expected
233
+ ) ) ,
234
+ ) ] ,
235
+ Some ( Arc :: new ( is_valid_request ( ) ) ) ,
236
+ ) ;
237
+ let credential = ClientAssertionCredential :: new (
238
+ FAKE_TENANT_ID . to_string ( ) ,
239
+ FAKE_CLIENT_ID . to_string ( ) ,
240
+ MockAssertion { } ,
241
+ Some ( ClientAssertionCredentialOptions {
242
+ credential_options : TokenCredentialOptions {
243
+ http_client : Arc :: new ( mock) ,
244
+ ..Default :: default ( )
245
+ } ,
246
+ ..Default :: default ( )
247
+ } ) ,
248
+ )
249
+ . expect ( "valid credential" ) ;
250
+
251
+ let error = credential
252
+ . get_token ( LIVE_TEST_SCOPES )
253
+ . await
254
+ . expect_err ( "authentication error" ) ;
255
+ assert ! ( matches!( error. kind( ) , ErrorKind :: Credential ) ) ;
256
+ assert ! (
257
+ error. to_string( ) . contains( expected) ,
258
+ "expected error description from the response, got '{}'" ,
259
+ error
260
+ ) ;
261
+ }
262
+
263
+ #[ tokio:: test]
264
+ async fn get_token_success ( ) {
265
+ let mock = MockSts :: new (
266
+ vec ! [ Response :: from_bytes(
267
+ StatusCode :: Ok ,
268
+ Headers :: default ( ) ,
269
+ Bytes :: from( format!(
270
+ r#"{{"access_token":"{}","expires_in":3600,"token_type":"Bearer"}}"# ,
271
+ FAKE_TOKEN
272
+ ) ) ,
273
+ ) ] ,
274
+ Some ( Arc :: new ( is_valid_request ( ) ) ) ,
275
+ ) ;
276
+ let credential = ClientAssertionCredential :: new (
277
+ FAKE_TENANT_ID . to_string ( ) ,
278
+ FAKE_CLIENT_ID . to_string ( ) ,
279
+ MockAssertion { } ,
280
+ Some ( ClientAssertionCredentialOptions {
281
+ credential_options : TokenCredentialOptions {
282
+ http_client : Arc :: new ( mock) ,
283
+ ..Default :: default ( )
284
+ } ,
285
+ ..Default :: default ( )
286
+ } ) ,
287
+ )
288
+ . expect ( "valid credential" ) ;
289
+
290
+ let token = credential. get_token ( LIVE_TEST_SCOPES ) . await . expect ( "token" ) ;
291
+ assert_eq ! ( FAKE_TOKEN , token. token. secret( ) ) ;
292
+ assert ! ( token. expires_on > SystemTime :: now( ) ) ;
293
+ assert_eq ! ( UtcOffset :: UTC , token. expires_on. offset( ) ) ;
294
+
295
+ // MockSts will return an error if the credential sends another request
296
+ let cached_token = credential
297
+ . get_token ( LIVE_TEST_SCOPES )
298
+ . await
299
+ . expect ( "cached token" ) ;
300
+ assert_eq ! ( token. token. secret( ) , cached_token. token. secret( ) ) ;
301
+ assert_eq ! ( token. expires_on, cached_token. expires_on) ;
302
+ }
169
303
}
0 commit comments