@@ -34,17 +34,17 @@ import (
3434 "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
3535 armpolicy "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/policy"
3636 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
37- "github.com/coreos/go-oidc/v3/oidc"
3837 "github.com/digitorus/pkcs7"
3938 "github.com/go-jose/go-jose/v3/jwt"
4039 "github.com/gravitational/trace"
41- "github.com/jonboulle/clockwork "
40+ "github.com/zitadel/oidc/v3/pkg/oidc "
4241
4342 "github.com/gravitational/teleport/api/client"
4443 "github.com/gravitational/teleport/api/client/proto"
4544 workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
4645 "github.com/gravitational/teleport/api/types"
4746 "github.com/gravitational/teleport/lib/cloud/azure"
47+ liboidc "github.com/gravitational/teleport/lib/oidc"
4848 "github.com/gravitational/teleport/lib/utils"
4949)
5050
@@ -87,7 +87,7 @@ type attestedData struct {
8787}
8888
8989type accessTokenClaims struct {
90- jwt. Claims
90+ oidc. TokenClaims
9191 TenantID string `json:"tid"`
9292 Version string `json:"ver"`
9393
@@ -107,18 +107,29 @@ type accessTokenClaims struct {
107107 AzureResourceID string `json:"xms_az_rid"`
108108}
109109
110+ func (c * accessTokenClaims ) AsJWTClaims () jwt.Claims {
111+ return jwt.Claims {
112+ Issuer : c .Issuer ,
113+ Subject : c .Subject ,
114+ Audience : jwt .Audience (c .Audience ),
115+ Expiry : jwt .NewNumericDate (c .Expiration .AsTime ()),
116+ NotBefore : jwt .NewNumericDate (c .NotBefore .AsTime ()),
117+ IssuedAt : jwt .NewNumericDate (c .IssuedAt .AsTime ()),
118+ ID : c .JWTID ,
119+ }
120+ }
121+
110122type azureVerifyTokenFunc func (ctx context.Context , rawIDToken string ) (* accessTokenClaims , error )
111123
112124type vmClientGetter func (subscriptionID string , token * azure.StaticCredential ) (azure.VirtualMachinesClient , error )
113125
114126type azureRegisterConfig struct {
115- clock clockwork.Clock
116127 certificateAuthorities []* x509.Certificate
117128 verify azureVerifyTokenFunc
118129 getVMClient vmClientGetter
119130}
120131
121- func azureVerifyFuncFromOIDCVerifier (cfg * oidc. Config ) azureVerifyTokenFunc {
132+ func azureVerifyFuncFromOIDCVerifier (clientID string ) azureVerifyTokenFunc {
122133 return func (ctx context.Context , rawIDToken string ) (* accessTokenClaims , error ) {
123134 token , err := jwt .ParseSigned (rawIDToken )
124135 if err != nil {
@@ -133,32 +144,13 @@ func azureVerifyFuncFromOIDCVerifier(cfg *oidc.Config) azureVerifyTokenFunc {
133144 if err != nil {
134145 return nil , trace .Wrap (err )
135146 }
136- provider , err := oidc .NewProvider (ctx , issuer )
137- if err != nil {
138- return nil , trace .Wrap (err )
139- }
140- verifiedToken , err := provider .Verifier (cfg ).Verify (ctx , rawIDToken )
141- if err != nil {
142- return nil , trace .Wrap (err )
143- }
144- var tokenClaims accessTokenClaims
145- if err := verifiedToken .Claims (& tokenClaims ); err != nil {
146- return nil , trace .Wrap (err )
147- }
148- return & tokenClaims , nil
147+ return liboidc .ValidateToken [* accessTokenClaims ](ctx , issuer , clientID , rawIDToken )
149148 }
150149}
151150
152151func (cfg * azureRegisterConfig ) CheckAndSetDefaults (ctx context.Context ) error {
153- if cfg .clock == nil {
154- cfg .clock = clockwork .NewRealClock ()
155- }
156152 if cfg .verify == nil {
157- oidcConfig := & oidc.Config {
158- ClientID : azureAccessTokenAudience ,
159- Now : cfg .clock .Now ,
160- }
161- cfg .verify = azureVerifyFuncFromOIDCVerifier (oidcConfig )
153+ cfg .verify = azureVerifyFuncFromOIDCVerifier (azureAccessTokenAudience )
162154 }
163155
164156 if cfg .certificateAuthorities == nil {
@@ -278,7 +270,7 @@ func verifyVMIdentity(
278270 Time : requestStart ,
279271 }
280272
281- if err := tokenClaims .Validate (expectedClaims ); err != nil {
273+ if err := tokenClaims .AsJWTClaims (). Validate (expectedClaims ); err != nil {
282274 return nil , trace .Wrap (err )
283275 }
284276
@@ -301,7 +293,7 @@ func verifyVMIdentity(
301293
302294 tokenCredential := azure .NewStaticCredential (azcore.AccessToken {
303295 Token : accessToken ,
304- ExpiresOn : tokenClaims .Expiry . Time (),
296+ ExpiresOn : tokenClaims .GetExpiration (),
305297 })
306298 vmClient , err := cfg .getVMClient (subscriptionID , tokenCredential )
307299 if err != nil {
0 commit comments