Skip to content

Commit 1a04744

Browse files
committed
Use oidc.TokenClaims as a superset of jwt.Claims for azure claims.
1 parent 114e04c commit 1a04744

File tree

3 files changed

+47
-37
lines changed

3 files changed

+47
-37
lines changed

lib/auth/join_azure.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ type attestedData struct {
8888

8989
type accessTokenClaims struct {
9090
oidc.TokenClaims
91-
jwt.Claims
9291
TenantID string `json:"tid"`
9392
Version string `json:"ver"`
9493

@@ -108,6 +107,18 @@ type accessTokenClaims struct {
108107
AzureResourceID string `json:"xms_az_rid"`
109108
}
110109

110+
func (c *accessTokenClaims) AsJWTClaims() jwt.Claims {
111+
return jwt.Claims{
112+
Issuer: c.Issuer,
113+
Subject: c.Subject,
114+
Audience: jwt.Audience(c.Audience),
115+
Expiry: jwt.NewNumericDate(c.Expiration.AsTime()),
116+
NotBefore: jwt.NewNumericDate(c.NotBefore.AsTime()),
117+
IssuedAt: jwt.NewNumericDate(c.IssuedAt.AsTime()),
118+
ID: c.JWTID,
119+
}
120+
}
121+
111122
type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error)
112123

113124
type vmClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error)
@@ -259,7 +270,7 @@ func verifyVMIdentity(
259270
Time: requestStart,
260271
}
261272

262-
if err := tokenClaims.Validate(expectedClaims); err != nil {
273+
if err := tokenClaims.AsJWTClaims().Validate(expectedClaims); err != nil {
263274
return nil, trace.Wrap(err)
264275
}
265276

@@ -282,7 +293,7 @@ func verifyVMIdentity(
282293

283294
tokenCredential := azure.NewStaticCredential(azcore.AccessToken{
284295
Token: accessToken,
285-
ExpiresOn: tokenClaims.Expiry.Time(),
296+
ExpiresOn: tokenClaims.GetExpiration(),
286297
})
287298
vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential)
288299
if err != nil {

lib/auth/join_azure_test.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"github.com/google/uuid"
3535
"github.com/gravitational/trace"
3636
"github.com/stretchr/testify/require"
37+
"github.com/zitadel/oidc/v3/pkg/oidc"
3738

3839
"github.com/gravitational/teleport/api/client/proto"
3940
"github.com/gravitational/teleport/api/types"
@@ -144,14 +145,14 @@ func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time
144145
return "", trace.Wrap(err)
145146
}
146147
claims := accessTokenClaims{
147-
Claims: jwt.Claims{
148-
Issuer: "https://sts.windows.net/test-tenant-id/",
149-
Audience: []string{azureAccessTokenAudience},
150-
Subject: "test",
151-
IssuedAt: jwt.NewNumericDate(issueTime),
152-
NotBefore: jwt.NewNumericDate(issueTime),
153-
Expiry: jwt.NewNumericDate(issueTime.Add(time.Minute)),
154-
ID: "id",
148+
TokenClaims: oidc.TokenClaims{
149+
Issuer: "https://sts.windows.net/test-tenant-id/",
150+
Audience: []string{azureAccessTokenAudience},
151+
Subject: "test",
152+
IssuedAt: oidc.FromTime(issueTime),
153+
NotBefore: oidc.FromTime(issueTime),
154+
Expiration: oidc.FromTime(issueTime.Add(time.Minute)),
155+
JWTID: "id",
155156
},
156157
ManangedIdentityResourceID: managedIdentityResourceID,
157158
AzureResourceID: azureResourceID,

lib/githubactions/token_validator_test.go

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -381,9 +381,7 @@ func testSigner(t *testing.T) ([]byte, jose.Signer) {
381381
return jwksData, signer
382382
}
383383

384-
//nolint:govet // there's some weird json struct tag overlap here
385384
type claims struct {
386-
jwt.Claims
387385
IDTokenClaims
388386
Subject string `json:"sub"`
389387
}
@@ -409,14 +407,14 @@ func TestValidateTokenWithJWKS(t *testing.T) {
409407
claims: claims{
410408
IDTokenClaims: IDTokenClaims{
411409
Repository: "123",
410+
TokenClaims: oidc.TokenClaims{
411+
Audience: oidc.Audience{clusterName},
412+
IssuedAt: oidc.FromTime(now.Add(-1 * time.Minute)),
413+
NotBefore: oidc.FromTime(now.Add(-1 * time.Minute)),
414+
Expiration: oidc.FromTime(now.Add(10 * time.Minute)),
415+
},
412416
},
413417
Subject: "foo",
414-
Claims: jwt.Claims{
415-
Audience: jwt.Audience{clusterName},
416-
IssuedAt: jwt.NewNumericDate(now.Add(-1 * time.Minute)),
417-
NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Minute)),
418-
Expiry: jwt.NewNumericDate(now.Add(10 * time.Minute)),
419-
},
420418
},
421419
wantResult: &IDTokenClaims{
422420
Sub: "foo",
@@ -429,14 +427,14 @@ func TestValidateTokenWithJWKS(t *testing.T) {
429427
claims: claims{
430428
IDTokenClaims: IDTokenClaims{
431429
Repository: "123",
430+
TokenClaims: oidc.TokenClaims{
431+
Audience: oidc.Audience{clusterName},
432+
IssuedAt: oidc.FromTime(now.Add(-1 * time.Minute)),
433+
NotBefore: oidc.FromTime(now.Add(-1 * time.Minute)),
434+
Expiration: oidc.FromTime(now.Add(10 * time.Minute)),
435+
},
432436
},
433437
Subject: "foo",
434-
Claims: jwt.Claims{
435-
Audience: jwt.Audience{clusterName},
436-
IssuedAt: jwt.NewNumericDate(now.Add(-1 * time.Minute)),
437-
NotBefore: jwt.NewNumericDate(now.Add(-1 * time.Minute)),
438-
Expiry: jwt.NewNumericDate(now.Add(10 * time.Minute)),
439-
},
440438
},
441439
wantResult: &IDTokenClaims{
442440
Sub: "foo",
@@ -450,14 +448,14 @@ func TestValidateTokenWithJWKS(t *testing.T) {
450448
claims: claims{
451449
IDTokenClaims: IDTokenClaims{
452450
Repository: "123",
451+
TokenClaims: oidc.TokenClaims{
452+
Audience: oidc.Audience{clusterName},
453+
IssuedAt: oidc.FromTime(now.Add(-2 * time.Minute)),
454+
NotBefore: oidc.FromTime(now.Add(-2 * time.Minute)),
455+
Expiration: oidc.FromTime(now.Add(-1 * time.Minute)),
456+
},
453457
},
454458
Subject: "foo",
455-
Claims: jwt.Claims{
456-
Audience: jwt.Audience{clusterName},
457-
IssuedAt: jwt.NewNumericDate(now.Add(-2 * time.Minute)),
458-
NotBefore: jwt.NewNumericDate(now.Add(-2 * time.Minute)),
459-
Expiry: jwt.NewNumericDate(now.Add(-1 * time.Minute)),
460-
},
461459
},
462460
wantErr: "token is expired",
463461
},
@@ -467,14 +465,14 @@ func TestValidateTokenWithJWKS(t *testing.T) {
467465
claims: claims{
468466
IDTokenClaims: IDTokenClaims{
469467
Repository: "123",
468+
TokenClaims: oidc.TokenClaims{
469+
Audience: oidc.Audience{clusterName},
470+
IssuedAt: oidc.FromTime(now.Add(2 * time.Minute)),
471+
NotBefore: oidc.FromTime(now.Add(2 * time.Minute)),
472+
Expiration: oidc.FromTime(now.Add(4 * time.Minute)),
473+
},
470474
},
471475
Subject: "foo",
472-
Claims: jwt.Claims{
473-
Audience: jwt.Audience{clusterName},
474-
IssuedAt: jwt.NewNumericDate(now.Add(2 * time.Minute)),
475-
NotBefore: jwt.NewNumericDate(now.Add(2 * time.Minute)),
476-
Expiry: jwt.NewNumericDate(now.Add(4 * time.Minute)),
477-
},
478476
},
479477
wantErr: "token not valid yet",
480478
},

0 commit comments

Comments
 (0)